In [16]:
import torch
import torchvision

# Use `torch.jit.trace` if the model has no data-dependent control flow (tracing records only the operations executed for the example input)

In [17]:
# Fix the RNG so the randomly initialised weights and the example input (and
# hence the outputs shown below) are reproducible on Restart & Run All.
torch.manual_seed(0)

# An instance of your model.
model = torchvision.models.resnet18()
# Trace in inference mode: torch.jit.trace records the ops run for the example
# input, so modules whose behaviour depends on train/eval mode (BatchNorm here)
# keep the mode they had while being traced. Calling .eval() on the traced
# module afterwards does not change that.
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)  # NCHW

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

In [18]:
# torch.jit.trace hands back a ScriptModule subclass, not the original nn.Module.
traced_script_module.__class__

torch.jit.TopLevelTracedModule

In [19]:
# Run the traced module on a constant all-ones batch and inspect the result shape.
output = traced_script_module(
    torch.ones(1, 3, 224, 224)  # NCHW, same shape as the tracing example
)
output.shape

torch.Size([1, 1000])

In [20]:
# First five of the 1000 output values for the only image in the batch.
output[0][:5]

tensor([-0.4811, -0.1557,  0.7336, -0.0878,  0.9928], grad_fn=<SliceBackward>)

# Subclass `torch.jit.ScriptModule` if the model uses data-dependent control flow (scripting compiles the `forward` source, so every branch is kept)

In [8]:
import torch

class MyModule(torch.jit.ScriptModule):
    """Toy ScriptModule whose forward() branches on the values of its input.

    forward() is compiled with @torch.jit.script_method rather than traced,
    so both branches of its `if` are kept in the compiled code.

    Args:
        N: number of rows of the learnable weight matrix.
        M: number of columns of the weight; a 1-D input to forward() needs
            M elements for the matrix-vector product.
    """

    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        # bool(...) converts the 0-dim comparison tensor into a Python bool
        # so it can drive the branch.
        if bool(input.sum() > 0):
            # Matrix-vector product: (N, M) x (M,) -> (N,)
            result = self.weight.mv(input)
        else:
            # Broadcast add; for an (M,) input this yields an (N, M) tensor.
            result = self.weight + input
        return result


my_script_module = MyModule(2, 3)

In [13]:
# The input length must equal M (= 3) of MyModule(2, 3).
example2 = torch.rand(3)
output2 = my_script_module(example2)

# torch.rand samples from [0, 1), so example2 practically always has a positive
# sum and only the `mv` branch runs above. Negating it exercises the other
# branch as well, which is the reason for scripting instead of tracing.
n_rows, n_cols = my_script_module.weight.shape
assert output2.shape == (n_rows,)
assert my_script_module(-example2).shape == (n_rows, n_cols)
output2

tensor([1.3336, 0.6524], grad_fn=<MvBackward>)

# Serialize the traced module to a file

In [21]:
# Write the traced module (code and parameters) to disk; it can be reloaded
# later with torch.jit.load.
torch.jit.save(traced_script_module, 'traced_script_module.pt')