# Torch jit trace test for quantizer

### 1. How does the quantizer capture the whole graph?

The quantizer captures the whole graph by tracing the model in PyTorch. Tracing is an export method: it runs the model with example inputs and records the operations performed on all tensors. The quantizer uses two different PyTorch APIs to obtain the traced graph. The first is "_get_trace_graph", which is used to get the graph from a model without control flow. This internal API was designed for ONNX export and predates "torch.jit.trace". The second is "torch.jit.trace", which is used to get the graph from a model with control flow. In that case, the control-flow parts are scripted with "@script_if_tracing". Typically, this only requires a small refactor of the forward function to separate out the control-flow parts that need to be compiled. This does not mean that we fully support TorchScript. For the quantizer's requirements, we should use tracing for the majority of the logic and use scripting only when necessary.
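The split between traced and scripted code described above can be sketched as follows. This is a minimal illustration, not the quantizer's own code: `select_activation` and `SmallNet` are hypothetical names. Only the data-dependent branch is moved into a helper decorated with `torch.jit.script_if_tracing`, and the rest of `forward` stays traced.

```python
import torch
import torch.nn.functional as F


@torch.jit.script_if_tracing
def select_activation(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: compiled by scripting when called during tracing.
    if x.sum() > 0:
        return F.relu(x)
    else:
        return F.relu6(x)


class SmallNet(torch.nn.Module):  # hypothetical example model
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = x * 2.0               # recorded by tracing
        y = select_activation(y)  # scripted, so both branches are kept
        return y + 1.0            # recorded by tracing


# Trace with an input that takes the relu branch.
traced = torch.jit.trace(SmallNet(), torch.ones(3))

# This input takes the relu6 branch, which a plain trace would have dropped.
other_branch = torch.tensor([-10.0, 8.0, -10.0])
print(traced(other_branch))
```

Keeping the scripted helper small limits how much of the model has to satisfy TorchScript's stricter language rules.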

### 2. Tracing problems you should pay attention to

##### 1. Dynamic Control flow

In [1]:
import torch
def f(x):
    return torch.nn.functional.relu(x) if x.sum() > 0 else torch.nn.functional.relu6(x)

traced_script = torch.jit.trace(f, torch.randn(3))
print(traced_script.inlined_graph)

graph(%x : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %1 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::relu(%x) # /proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/torch/nn/functional.py:1457:0
  return (%1)



In this example, the trace keeps only one branch of the control flow: the one selected by the concrete example inputs. If we truly want to preserve the control flow in the function, we can use "@script_if_tracing".

In [2]:
import torch
@torch.jit.script_if_tracing
def f(x):
    return torch.nn.functional.relu(x) if x.sum() > 0 else torch.nn.functional.relu6(x)

traced_script = torch.jit.trace(f, torch.randn(3))
print(traced_script.inlined_graph)

graph(%x : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %1 : Function = prim::Constant[name="f"]()
  %3 : int = prim::Constant[value=0]() # /tmp/ipykernel_28127/3830558629.py:4:52
  %4 : NoneType = prim::Constant()
  %5 : Tensor = aten::sum(%x, %4) # /tmp/ipykernel_28127/3830558629.py:4:42
  %6 : Tensor = aten::gt(%5, %3) # /tmp/ipykernel_28127/3830558629.py:4:42
  %7 : bool = aten::Bool(%6) # /tmp/ipykernel_28127/3830558629.py:4:42
  %8 : Tensor = prim::If(%7) # /tmp/ipykernel_28127/3830558629.py:4:11
    block0():
      %result.6 : Tensor = aten::relu(%x) # /proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/torch/nn/functional.py:1457:17
      -> (%result.6)
    block1():
      %result.3 : Tensor = aten::relu6(%x) # /proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/torch/nn/functional.py:1534:17
      -> (%result.3)
  return (%8)



##### 2. Freeze variables as constants

In [3]:
x = torch.rand(1)
y = torch.rand(2)
def f(x): return torch.arange(len(x))
traced_script = torch.jit.trace(f, x)
print(traced_script.code)
traced_script(y)

def f(x: Tensor) -> Tensor:
  _0 = torch.arange(1, dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
  return _0



tensor([0])

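Intermediate computation results of a non-Tensor type (in this case, the Python int returned by len(x)) may be frozen as constants, using the value observed during tracing. This prevents the trace from generalizing: above, the traced function still returns tensor([0]) for y, even though y has two elements. We should use symbolic shapes such as x.size(0) instead.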

In [4]:
import torch
x = torch.rand(1)
y = torch.rand(2)
def f(x): return torch.arange(x.size(0))
traced_script = torch.jit.trace(f, x)
print(traced_script.code)
traced_script(y)

def f(x: Tensor) -> Tensor:
  _0 = ops.prim.NumToTensor(torch.size(x, 0))
  _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
  return _1



tensor([0, 1])
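As a further sketch of the same point, the two spellings can be compared side by side on an input whose length differs from the tracing example. The function names `frozen_length` and `symbolic_length` are hypothetical, and `x.shape[0]` is used here as an alternative spelling of `x.size(0)`.

```python
import warnings

import torch


def frozen_length(x):
    return torch.zeros(len(x))      # len() yields a plain Python int


def symbolic_length(x):
    return torch.zeros(x.shape[0])  # the shape access is recorded in the trace


example = torch.rand(1)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the TracerWarning raised for len()
    traced_frozen = torch.jit.trace(frozen_length, example)
traced_symbolic = torch.jit.trace(symbolic_length, example)

# Run both traces on an input that is longer than the tracing example.
longer = torch.rand(4)
print(traced_frozen(longer).shape)    # still sized like the tracing example
print(traced_symbolic(longer).shape)  # follows the new input
```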

##### 3. Freeze device

In [5]:
import torch
def f(x):
    return torch.as_tensor(x, device=x.device)
traced_script = torch.jit.trace(f, torch.randn(2))
print(traced_script.code)

def f(x: Tensor) -> Tensor:
  return torch.to(x, torch.device("cpu"), 6)



The device attribute of the input is frozen during tracing, so the traced script may not generalize to inputs on a different device. Such generalization is almost never needed, because deployment usually has a fixed target device.
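If no second physical device is available, the frozen device can still be observed with PyTorch's "meta" device. The following is a sketch under that assumption; `zeros_like_input` is a hypothetical function, not part of the quantizer.

```python
import torch


def zeros_like_input(x):
    # In eager mode, the result is created on the device of the input.
    return torch.zeros(x.shape, device=x.device)


# Trace on CPU; the device is recorded as a constant in the traced code.
traced = torch.jit.trace(zeros_like_input, torch.randn(2))
print(traced.code)

# Call both versions with an input that lives on a different device.
meta_input = torch.empty(2, device="meta")
print(zeros_like_input(meta_input).device)  # eager follows the input device
print(traced(meta_input).device)            # trace keeps the tracing device
```

Tracing on the deployment device avoids the mismatch.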

##### 4. Input/output format

The model's inputs and outputs have to be "Union[Tensor, Tuple[Tensor]]" to be traceable. This format requirement applies only to the outermost model, so it is easy to address. If the model uses richer formats such as "Dict[str, Tensor]", just create a simple wrapper around it that converts to and from "Tuple[Tensor]".

In [6]:
import torch
from typing import Dict
class RichFormatModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x:Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        y = {}
        y["a"] = torch.sqrt(x["a"])
        y["b"] = torch.square(x["b"])
        return y
input = {"a": torch.tensor(2.0), "b": torch.tensor(3.0)}
output = RichFormatModel()(input)
print(output)
trace_script = torch.jit.trace(RichFormatModel(), input)
print(trace_script.inlined_graph)

{'a': tensor(1.4142), 'b': tensor(9.)}


RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.

We can add wrappers that manually flatten the input and rebuild RichFormatModel's rich-format output from the flattened output, so that the model is suitable for both tracing and downstream tasks. Hopefully we can automate the format conversion in the near future.

In [7]:
from typing import Tuple

class RichFormatWrapper(torch.nn.Module):
    def __init__(self, trace_model):
        super().__init__()
        self.trace_model = trace_model
        
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        flatten_x = x["a"], x["b"]
        flatten_outputs = self.trace_model(*flatten_x)
        return {"a": flatten_outputs[0], "b": flatten_outputs[1]}
    
class TraceWrapper(torch.nn.Module):
    def __init__(self, origin_model):
        super().__init__()
        self.origin_model = origin_model
    
    def forward(self, *x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        dict_inputs = {"a": x[0], "b": x[1]}
        dict_outputs = self.origin_model(dict_inputs)
        flatten_outputs = dict_outputs["a"], dict_outputs["b"]
        return flatten_outputs
    
trace_model = TraceWrapper(RichFormatModel())
flatten_inputs = input["a"], input["b"]
trace_script = torch.jit.trace(trace_model, flatten_inputs)
new_model = RichFormatWrapper(trace_script)
outputs = new_model(input)
print(outputs)

{'a': tensor(1.4142), 'b': tensor(9.)}
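One possible direction for automating the conversion is to derive the flattening order from the dictionary keys instead of hard-coding "a" and "b". The sketch below assumes flat dictionaries of tensors with a fixed key set. `DictModel`, `FlatAdapter` and `trace_dict_model` are hypothetical names introduced only for illustration.

```python
from typing import Dict, List

import torch


class DictModel(torch.nn.Module):
    # Stand-in for a model with dict inputs and outputs.
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {"a": torch.sqrt(x["a"]), "b": torch.square(x["b"])}


class FlatAdapter(torch.nn.Module):
    # Traceable view of the model: positional tensors in, tuple of tensors out.
    def __init__(self, model: torch.nn.Module, in_keys: List[str], out_keys: List[str]):
        super().__init__()
        self.model = model
        self.in_keys = in_keys
        self.out_keys = out_keys

    def forward(self, *tensors: torch.Tensor):
        outputs = self.model(dict(zip(self.in_keys, tensors)))
        return tuple(outputs[key] for key in self.out_keys)


def trace_dict_model(model: torch.nn.Module, example: Dict[str, torch.Tensor]):
    in_keys = sorted(example)
    out_keys = sorted(model(example))  # one eager run to learn the output keys
    flat_example = tuple(example[key] for key in in_keys)
    traced = torch.jit.trace(FlatAdapter(model, in_keys, out_keys), flat_example)

    def run(inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Flatten in key order, call the trace, then rebuild the dict.
        flat_outputs = traced(*(inputs[key] for key in in_keys))
        return dict(zip(out_keys, flat_outputs))

    return run


run = trace_dict_model(DictModel(), {"a": torch.tensor(2.0), "b": torch.tensor(3.0)})
print(run({"a": torch.tensor(4.0), "b": torch.tensor(3.0)}))
```

Sorting the keys gives a deterministic order, so the traced graph and the rebuilt dictionary agree on which tensor is which. Nested containers would need a recursive flattening step that this sketch does not cover.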


### 3. How to pass the jit test?

step 1: Run the torch.jit.trace test; refer to https://pytorch.org/docs/stable/generated/torch.jit.trace.html?highlight=torch+jit+trace#torch.jit.trace. If you encounter the error "TracingCheckError: Tracing failed sanity checks!", it means that tracing your model twice with the same inputs produces different graphs. You can set "check_trace=False" to work around it.
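The same sanity check can also be pointed at additional inputs through the `check_inputs` argument of `torch.jit.trace`, which helps surface input-dependent graphs at trace time. The sketch below reuses the control-flow function from section 2; the tensors `positive` and `negative` are made-up example data.

```python
import warnings

import torch
import torch.nn.functional as F


def f(x):
    return F.relu(x) if x.sum() > 0 else F.relu6(x)


positive = torch.tensor([1.0, 2.0, 3.0])     # takes the relu branch
negative = torch.tensor([-1.0, -2.0, -3.0])  # takes the relu6 branch

check_failed = False
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide TracerWarnings about the branch
    try:
        # Trace with one input and re-check the graph against another one.
        torch.jit.trace(f, positive, check_inputs=[(negative,)])
    except torch.jit.TracingCheckError as err:
        check_failed = True
        print("check failed:", type(err).__name__)
```

A failure here indicates that the graph depends on the input values, which is the situation the control-flow refactor in section 2 is meant to address.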

step 2: Use the traced script for evaluation testing to ensure that it behaves correctly. If the evaluation test shows problems, the likely reason is that the traced script depends on the inputs used for tracing. You can try tracing with real data instead of dummy data. If the problem persists, check and modify your model so that it is independent of specific inputs.
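A minimal sketch of such an evaluation test is to compare eager and traced outputs on several batches that differ from the tracing example. `TinyModel` and `max_abs_error` are hypothetical stand-ins, and random tensors take the place of real evaluation data here.

```python
import torch


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for the model under test.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


def max_abs_error(model, traced, batches):
    # Largest element-wise difference between eager and traced outputs.
    worst = 0.0
    with torch.no_grad():
        for batch in batches:
            diff = (model(batch) - traced(batch)).abs().max().item()
            worst = max(worst, diff)
    return worst


torch.manual_seed(0)
model = TinyModel().eval()
traced = torch.jit.trace(model, torch.randn(1, 4))

# Use batch sizes and values other than the tracing example.
batches = [torch.randn(n, 4) for n in (1, 3, 8)]
print(max_abs_error(model, traced, batches) < 1e-5)
```

A large error on inputs that differ from the tracing example points to one of the problems listed in section 2.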