In [1]:
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)
        self.linear2 = torch.nn.Linear(5, 5)

    def forward(self, x):
        out = self.linear(x + self.param)
        return self.linear2(x).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)


In [2]:
print(symbolic_traced.graph)

graph():
    %x : [#users=2] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=0] = call_module[target=linear](args = (%add,), kwargs = {})
    %linear2 : [#users=1] = call_module[target=linear2](args = (%x,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear2,), kwargs = {min: 0.0, max: 1.0})
    return clamp


In [3]:
g = symbolic_traced.graph

In [4]:
g.print_tabular()

opcode         name     target                   args        kwargs
-------------  -------  -----------------------  ----------  ------------------------
placeholder    x        x                        ()          {}
get_attr       param    param                    ()          {}
call_function  add      <built-in function add>  (x, param)  {}
call_module    linear   linear                   (add,)      {}
call_module    linear2  linear2                  (x,)        {}
call_method    clamp    clamp                    (linear2,)  {'min': 0.0, 'max': 1.0}
output         output   output                   (clamp,)    {}


In [5]:
for node in g.nodes:
    if node.op in ["call_module", "call_method"]:
        print(node.op, node.target, node.args)

call_module linear (add,)
call_module linear2 (x,)
call_method clamp (linear2,)


In [6]:
print(symbolic_traced.code)


def forward(self, x):
    param = self.param
    add = x + param;  param = None
    linear = self.linear(add);  add = None
    linear2 = self.linear2(x);  x = None
    clamp = linear2.clamp(min = 0.0, max = 1.0);  linear2 = None
    return clamp
    


In [7]:
type(symbolic_traced)

torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl

In [8]:
from torch.fx.passes.shape_prop import ShapeProp

In [9]:
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)
    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
model = TwoLayerNet(D_in, H, D_out)
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(50, D_in)
ShapeProp(gm).propagate(sample_input)


tensor([[-2.4654e-01,  1.7194e-02,  6.8955e-02,  2.2021e-01,  4.4075e-02,
          8.3911e-02,  7.3841e-02, -5.7304e-02, -2.3597e-01,  2.6656e-01],
        [-3.1227e-01, -1.0711e-01,  2.1948e-01,  1.0238e-01, -5.7619e-02,
         -1.2323e-01, -2.2494e-01, -4.0803e-01, -6.5411e-02,  4.4413e-01],
        [-2.6080e-01,  3.2992e-01,  7.2437e-02, -2.4618e-01, -1.4985e-01,
         -2.4337e-01, -3.2790e-02, -4.7158e-01, -4.2521e-01,  1.9706e-01],
        [-4.4225e-01, -2.7483e-02, -3.0650e-01,  1.6262e-01, -4.4571e-02,
         -1.6837e-01,  1.8660e-01,  3.1474e-02, -4.1161e-01,  4.4892e-01],
        [-2.0672e-01, -1.3015e-02, -2.0443e-02,  2.3464e-01,  1.6306e-01,
         -8.4241e-02,  1.2211e-01, -1.4450e-01, -1.1165e-01,  7.1051e-01],
        [-1.0010e+00, -1.3849e-01,  2.3965e-01,  1.4083e-02,  4.2961e-01,
         -1.9452e-01,  2.4543e-01,  1.1455e-01, -2.1688e-01,  3.2175e-01],
        [-6.2559e-01,  3.8118e-01, -2.0298e-02,  1.8196e-01,  9.5967e-02,
         -1.0640e-01,  8.4253e-0

In [10]:
for node in gm.graph.nodes:
    print(node.name, node.meta['tensor_meta'].dtype,
        node.meta['tensor_meta'].shape)

x torch.float32 torch.Size([50, 1000])
linear1 torch.float32 torch.Size([50, 100])
clamp torch.float32 torch.Size([50, 100])
linear2 torch.float32 torch.Size([50, 10])
output torch.float32 torch.Size([50, 10])


In [11]:
from torchvision import models

In [12]:
net = models.resnet18()

In [19]:
gm = torch.fx.symbolic_trace(net)
sample_input = torch.randn(1, 3 ,224, 224)
ShapeProp(gm).propagate(sample_input)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


tensor([[-3.3239e-02, -7.5997e-01, -1.3424e-01, -1.2014e-01, -4.2429e-02,
          5.7649e-02,  5.2269e-01, -7.8221e-01, -3.5845e-01, -2.2237e-01,
         -1.0024e-01,  5.4364e-01, -1.8134e-01, -8.0633e-01,  4.2117e-01,
          3.8587e-01, -3.7670e-01,  7.7968e-01,  3.3259e-01,  1.6050e-01,
          1.1636e-01, -2.6960e-01,  2.0935e-01,  1.1797e-02, -3.2981e-01,
         -8.8132e-02,  8.8123e-02,  6.5380e-01, -2.5806e-01,  1.6193e-01,
         -3.9047e-01, -4.8978e-01, -1.4013e-01, -1.8331e-01, -4.8718e-02,
          6.3806e-02,  6.9854e-01, -9.1238e-01,  2.4740e-01, -6.3716e-01,
          6.2757e-01, -7.6974e-01, -8.5881e-01, -7.0892e-02,  3.4040e-01,
          3.7809e-01, -6.4334e-01, -4.5084e-01, -3.6080e-01,  4.0448e-01,
         -7.0230e-02,  2.6103e-01, -3.8110e-02,  3.8979e-01,  3.4592e-01,
         -3.2595e-01,  3.2237e-01,  5.0795e-01,  6.2589e-01,  4.0218e-01,
         -3.5472e-01, -3.0060e-01, -6.5672e-01, -3.1811e-01,  2.5244e-01,
          8.0378e-02, -8.3943e-01,  1.

In [20]:
gm.graph.print_tabular()

opcode         name                   target                                                   args                                   kwargs
-------------  ---------------------  -------------------------------------------------------  -------------------------------------  --------
placeholder    x                      x                                                        ()                                     {}
call_module    conv1                  conv1                                                    (x,)                                   {}
call_module    bn1                    bn1                                                      (conv1,)                               {}
call_module    relu                   relu                                                     (bn1,)                                 {}
call_module    maxpool                maxpool                                                  (relu,)                                {}
call_module    layer1_0_conv1  

In [22]:
gm.graph.print_tabular()

opcode         name                   target                                                   args                                   kwargs
-------------  ---------------------  -------------------------------------------------------  -------------------------------------  --------
placeholder    x                      x                                                        ()                                     {}
call_module    conv1                  conv1                                                    (x,)                                   {}
call_module    bn1                    bn1                                                      (conv1,)                               {}
call_module    relu                   relu                                                     (bn1,)                                 {}
call_module    maxpool                maxpool                                                  (relu,)                                {}
call_module    layer1_0_conv1  

In [46]:
for node in gm.graph.nodes:
    # print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}")
    if node.op == "call_module":
        print(f"{node.target},\t{node.op},\t params:{net.state_dict()[node.target + '.weight'].shape} out:{node.meta['tensor_meta'].shape}")
    


conv1,	call_module,	 params:torch.Size([64, 3, 7, 7]) out:torch.Size([1, 64, 112, 112])
bn1,	call_module,	 params:torch.Size([64]) out:torch.Size([1, 64, 112, 112])


KeyError: 'relu.weight'

In [40]:
net.state_dict()['layer1.1.conv2.weight'].shape

torch.Size([64, 64, 3, 3])

In [47]:
node.meta['tensor_meta']

TensorMetadata(shape=torch.Size([1, 64, 112, 112]), dtype=torch.float32, stride=(802816, 12544, 112, 1), memory_format=torch.contiguous_format, is_quantized=False, qscheme=None, q_scale=None, q_zero_point=None)

In [48]:
class ShapeProp(torch.fx.Interpreter):
    def run_node(self, n : Node) -> Any:
        result = super().run_node(n)
        
        found_tensor = False
        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return extract_tensor_metadata(obj)
            else:
                return obj

        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Run `module` via interpretation and return the result and
        record the shape and type of each node.
        Args:
            *args (Tensor): the sample input.
        Returns:
            Any: The value returned from executing the Module
        """
        return super().run(*args)

NameError: name 'TensorMetadata' is not defined