In [1]:
import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    """Mean-subtraction wrapper around a traced backbone (scripting + tracing demo).

    ``forward`` subtracts a per-channel ``means`` tensor from its input and
    feeds the result to ``self.resnet``, a module produced by
    ``torch.jit.trace`` at construction time. The wrapper itself can then be
    compiled with ``torch.jit.script``.

    Args:
        means: per-channel values subtracted from the input. Stored with shape
            ``(1, len(means), 1, 1)`` so it broadcasts against a 4-D input
            whose dim 1 has ``len(means)`` entries. Defaults to the three
            constants this example has always used.
        backbone: module (or function) to trace. Defaults to a freshly built
            ``torchvision.models.resnet18()`` (random weights). The traced
            result is stored as ``self.resnet`` whatever it is, keeping the
            original attribute / state_dict names.
        example_input: example tensor handed to ``torch.jit.trace``. Defaults
            to ``torch.rand(1, 3, 224, 224)``.
    """

    def __init__(self, means=(103.939, 116.779, 123.68), backbone=None,
                 example_input=None):
        super(MyScriptModule, self).__init__()
        # Fixed preprocessing constants belong in a buffer, not an
        # nn.Parameter: a buffer still follows .to()/.cuda() and is saved in
        # state_dict() under the same "means" key, but it gets no gradient
        # and is not returned by parameters(), so an optimizer won't touch it.
        # view() replaces resize_(): same (1, 3, 1, 1) result for the default
        # constants, without resize_'s raw storage-resizing semantics.
        self.register_buffer(
            "means",
            torch.tensor(means, dtype=torch.get_default_dtype())
            .view(1, -1, 1, 1))
        if backbone is None:
            backbone = torchvision.models.resnet18()
        if example_input is None:
            example_input = torch.rand(1, 3, 224, 224)
        # NOTE(review): torch.jit.trace records the ops run for the example
        # input, so mode-dependent behavior is presumably fixed at trace time.
        # A freshly built resnet18 is in training mode (BatchNorm uses batch
        # statistics). Confirm, and consider backbone.eval() before tracing
        # if the scripted module is meant for inference.
        self.resnet = torch.jit.trace(backbone, example_input)

    def forward(self, input):
        # Broadcast-subtract the (1, C, 1, 1) means, then run the traced backbone.
        return self.resnet(input - self.means)

# Compile the wrapper with TorchScript: its forward() is scripted, while the
# `resnet` submodule it calls was already produced by torch.jit.trace.
my_script_module = torch.jit.script(obj=MyScriptModule())

In [6]:
# Show the scripted module's submodule tree, then its Python class, one per line.
print(my_script_module, type(my_script_module), sep="\n")


ScriptModule(
  original_name=MyScriptModule
  (resnet): TracedModule[ResNet](
    original_name=ResNet
    (conv1): TracedModule[Conv2d](original_name=Conv2d)
    (bn1): TracedModule[BatchNorm2d](original_name=BatchNorm2d)
    (relu): TracedModule[ReLU](original_name=ReLU)
    (maxpool): TracedModule[MaxPool2d](original_name=MaxPool2d)
    (layer1): TracedModule[Sequential](
      original_name=Sequential
      (0): TracedModule[BasicBlock](
        original_name=BasicBlock
        (conv1): TracedModule[Conv2d](original_name=Conv2d)
        (bn1): TracedModule[BatchNorm2d](original_name=BatchNorm2d)
        (relu): TracedModule[ReLU](original_name=ReLU)
        (conv2): TracedModule[Conv2d](original_name=Conv2d)
        (bn2): TracedModule[BatchNorm2d](original_name=BatchNorm2d)
      )
      (1): TracedModule[BasicBlock](
        original_name=BasicBlock
        (conv1): TracedModule[Conv2d](original_name=Conv2d)
        (bn1): TracedModule[BatchNorm2d](original_name=BatchNorm2d)
   

In [4]:
@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    """Return a 3x4 tensor built by a data-dependent loop (scripting demo).

    Starting from zeros, each of the ``len`` iterations subtracts 1.0 while
    the loop index is below 10 and adds 1.0 afterwards. For ``len >= 0`` every
    element therefore ends at ``-min(len, 10) + max(len - 10, 0)``; a negative
    ``len`` runs no iterations and returns zeros. Because the loop count and
    the branch depend on a runtime value, the compiled graph keeps them as
    control flow (prim::Loop / prim::If).

    Args:
        len: number of loop iterations. The name shadows the ``len`` builtin
            inside this function only.

    Returns:
        A (3, 4) tensor of the default floating dtype.
    """
    rv = torch.zeros(3, 4)
    for i in range(len):
        # First ten iterations count down, any further ones count back up.
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

# Show what @torch.jit.script turned `foo` into, then the TorchScript source
# the compiler reconstructed from it.
print(type(foo), foo.code, sep="\n")

<class 'torch._C.Function'>
def foo(len: int) -> Tensor:
  rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  rv0 = rv
  for i in range(len):
    if torch.lt(i, 10):
      rv1 = torch.sub(rv0, 1., 1)
    else:
      rv1 = torch.add(rv0, 1., 1)
    rv0 = rv1
  return rv0



In [7]:
# Dump the lower-level IR graph of the scripted `foo` (nodes annotated with
# their source locations).
print(str(foo.graph))

graph(%len.1 : int):
  %20 : int = prim::Constant[value=1]()
  %13 : bool = prim::Constant[value=1]() # <ipython-input-4-01a58e79a588>:5:4
  %5 : None = prim::Constant()
  %1 : int = prim::Constant[value=3]() # <ipython-input-4-01a58e79a588>:4:21
  %2 : int = prim::Constant[value=4]() # <ipython-input-4-01a58e79a588>:4:24
  %16 : int = prim::Constant[value=10]() # <ipython-input-4-01a58e79a588>:6:15
  %19 : float = prim::Constant[value=1]() # <ipython-input-4-01a58e79a588>:7:22
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %5, %5, %5, %5) # <ipython-input-4-01a58e79a588>:4:9
  %rv : Tensor = prim::Loop(%len.1, %13, %rv.1) # <ipython-input-4-01a58e79a588>:5:4
    block0(%i.1 : int, %rv.14 : Tensor):
      %17 : bool = aten::lt(%i.1, %16) # <ipython-input-4-01a58e79a588>:6:11
      %rv.13 : Tensor = prim::If(%17) # <ipython-input-4-01a58e79a588>:6:8
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %19, %20) # <ipython-input-4-01a58e79a588>:7:

In [8]:
def fill_row_zero(x):
    """Overwrite the first row of ``x`` in place with random values.

    The replacement is drawn from ``torch.rand`` and sized by ``x.shape[1:2]``.
    For a 2-D ``x`` of shape (R, C) that is C samples from [0, 1). Every call
    therefore produces different values. The same (mutated) tensor object is
    returned.
    """
    fresh_row = torch.rand(*x.shape[1:2])
    x[0] = fresh_row
    return x

# Trace fill_row_zero on one random 3x4 example and dump the recorded graph.
# torch.jit.trace checks the trace against the Python function by default.
# fill_row_zero draws from torch.rand, so the two disagree and the checker
# emits warnings.
traced = torch.jit.trace(
    func=fill_row_zero,
    example_inputs=(torch.rand(3, 4),),
)
print(traced.graph)

graph(%x : Float(3, 4)):
  %4 : int = prim::Constant[value=1]() # <ipython-input-8-9083ec5950df>:2:0
  %5 : int = aten::size(%x, %4) # <ipython-input-8-9083ec5950df>:2:0
  %6 : Long() = prim::NumToTensor(%5)
  %7 : int = aten::Int(%6)
  %8 : int[] = prim::ListConstruct(%7)
  %9 : int = prim::Constant[value=6]() # <ipython-input-8-9083ec5950df>:2:0
  %10 : int = prim::Constant[value=0]() # <ipython-input-8-9083ec5950df>:2:0
  %11 : Device = prim::Constant[value="cpu"]() # <ipython-input-8-9083ec5950df>:2:0
  %12 : bool = prim::Constant[value=0]() # <ipython-input-8-9083ec5950df>:2:0
  %13 : Float(4) = aten::rand(%8, %9, %10, %11, %12) # <ipython-input-8-9083ec5950df>:2:0
  %14 : int = prim::Constant[value=0]() # <ipython-input-8-9083ec5950df>:2:0
  %15 : int = prim::Constant[value=0]() # <ipython-input-8-9083ec5950df>:2:0
  %16 : Float(4) = aten::select(%x, %14, %15) # <ipython-input-8-9083ec5950df>:2:0
  %17 : int = prim::Constant[value=4]() # <ipython-input-8-9083ec5950df>:2:0
  %18 :

	%13 : Float(4) = aten::rand(%8, %9, %10, %11, %12) # <ipython-input-8-9083ec5950df>:2:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  _check_trace([example_inputs], func, traced, check_tolerance, _force_outplace, False, _module_class)
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 2] (0.6968618035316467 vs. 0.20274019241333008) and 3 other locations (33.00%)
  _check_trace([example_inputs], func, traced, check_tolerance, _force_outplace, False, _module_class)
