Introduction

Graph tracer for PyTorch.

Background

It is hard to obtain the connections between nodes in a computation graph without calling torch.jit.trace. However, once the JIT ops are invoked, some of the nodes in the graph may be removed, which makes it difficult to restore the graph in Python. With our tracer, you can get the connections between the nodes while the model remains runnable in Python. In addition, it can be used as a code generator that organizes the code of your model into a single script.

Sample usage

import torch
import torchvision

from tinynn.graph.tracer import model_tracer, trace

with model_tracer():
    # Prepare the model
    model = torchvision.models.alexnet()
    model.eval()

    # Provide a viable input for the model
    dummy_input = torch.rand((1, 3, 224, 224))

    # After tracing the model, we will get a TraceGraph object
    graph = trace(model, dummy_input)

    # We can use it to generate the code for the original model
    graph.generate_code('my_alexnet.py', 'my_alexnet.pth', 'Alexnet')
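
Given the call above, you should end up with a my_alexnet.py script defining an Alexnet class and a my_alexnet.pth weights file. Below is a minimal sketch of how they might be consumed; the no-argument constructor and the plain state-dict format are assumptions, so check the generated files for the exact convention.

import torch

from my_alexnet import Alexnet  # the generated script and class named above

# Rebuild the model from the generated code and restore its weights.
# (Assumes the generated class takes no constructor arguments and that
# the .pth file stores a plain state dict; adjust if your version differs.)
model = Alexnet()
model.load_state_dict(torch.load('my_alexnet.pth'))
model.eval()

with torch.no_grad():
    out = model(torch.rand((1, 3, 224, 224)))
print(out.shape)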

Quantization

According to the official PyTorch tutorial, you'll have to do the following things to get a quantized model running on devices.

  1. Insert QuantStub nodes after the inputs and DeQuantStub nodes before the outputs
  2. Rewrite the unsupported ops (e.g. torch.add(x, y) -> FloatFunctional.add(x, y))
  3. Fuse the modules (e.g. Conv + BatchNorm + ReLU -> ConvBnReLU)
  4. Convert the model to its quantized version
  5. Trace the model with JIT and serialize it to a TorchScript model file
  6. Run the inference through PyTorch Mobile

Steps 1-3 require a lot of manual work, which is cumbersome and error-prone. Therefore, we came up with the idea of writing an automatic quantization preparation tool.
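
For context, here is a minimal sketch of what steps 1-3 look like when done by hand with PyTorch's eager-mode quantization APIs; the QuantReadyNet module is an illustrative assumption, not part of tinynn.

import torch
import torch.nn as nn

class QuantReadyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Step 1: QuantStub / DeQuantStub mark the float <-> quantized boundary
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        # Step 2: functional ops like torch.add must be rewritten as FloatFunctional
        self.add = nn.quantized.FloatFunctional()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        y = self.relu(self.bn(self.conv(x)))
        y = self.add.add(y, y)  # instead of torch.add(y, y)
        return self.dequant(y)

model = QuantReadyNet().eval()
# Step 3: fuse Conv + BatchNorm + ReLU into a single fused module
model = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])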

What's more

  1. The model to be traced may be instantiated either inside or outside the with-block.
  2. You may trace multiple models in one with-block. (Both points are illustrated in the sketch after this list.)
  3. Runtime-defined constants are supported. If a constant is too large, it will be transformed into a parameter.
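
A minimal sketch of points 1 and 2, reusing the API from the sample above; the specific torchvision models are arbitrary choices.

import torch
import torchvision

from tinynn.graph.tracer import model_tracer, trace

# Point 1: models may also be created outside the with-block.
resnet = torchvision.models.resnet18().eval()

with model_tracer():
    mobilenet = torchvision.models.mobilenet_v2().eval()
    dummy_input = torch.rand((1, 3, 224, 224))

    # Point 2: multiple models can be traced in the same with-block.
    resnet_graph = trace(resnet, dummy_input)
    mobilenet_graph = trace(mobilenet, dummy_input)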

Limitations

  1. Like torch.jit.trace, if you trace a model that contains control-flow ops, you may silently get incorrect results on subsequent invocations of the model.
  2. Only the flow of PyTorch tensors is tracked. Other variables (e.g. NumPy arrays, numbers and strings) won't be tracked and will be treated as constants.
  3. Only some of a tensor's properties are tracked. For example, if you access .data or .shape of a tensor, the returned values will be joined into the computation graph. The tracked properties are listed below.
  • .data
  • .shape
  • .device
  • .dtype
  4. Calling numel on a torch.Size object obtained via .size() or .shape of a tensor is not supported. Please use torch.prod instead, as shown in the sketch after this list.
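
A minimal sketch of the workaround for the numel limitation; computing the element count as a product of the dimensions is the assumed intent of the torch.prod suggestion.

import torch

x = torch.rand((1, 3, 224, 224))

# Not supported under tracing: numel() on a torch.Size object
# n = x.shape.numel()

# Workaround: take the product of the dimensions instead
n = torch.prod(torch.tensor(x.shape))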