# Scripting a function

In [25]:
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))

<class 'torch.jit.ScriptFunction'>
def foo(x: Tensor,
    y: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.max(x), torch.max(y)))
  if _0:
    r = x
  else:
    r = y
  return r



tensor([[1., 1.],
        [1., 1.]])

# Scripting a function using example_inputs

In [26]:
import torch

def test_sum(a, b):
    return a + b

# Annotate the arguments to be int
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

print(type(scripted_fn))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(scripted_fn.code)

# Call the function using the TorchScript interpreter
a = torch.as_tensor(20)
b = torch.as_tensor(100)
scripted_fn(a, b)

<class 'torch.jit.ScriptFunction'>
def test_sum(a: Tensor,
    b: Tensor) -> Tensor:
  return torch.add(a, b)





tensor(120)

# Example (scripting a simple module with a Parameter):

Scripting an nn.Module by default will compile the forward method and recursively compile any methods, submodules, and functions called by forward.

In [27]:
class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

print(scripted_module.code)

def forward(self,
    input: Tensor) -> Tensor:
  weight = self.weight
  output = torch.mv(weight, input)
  linear = self.linear
  return (linear).forward(output, )



# Example (scripting a module with traced submodules):

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

print(scripted_module.code)

def forward(self,
    input: Tensor) -> Tensor:
  conv1 = self.conv1
  input0 = __torch__.torch.nn.functional.relu((conv1).forward(input, ), False, )
  conv2 = self.conv2
  input1 = __torch__.torch.nn.functional.relu((conv2).forward(input0, ), False, )
  return input1



# Example (an exported and ignored method in a module):

To compile a method other than forward (and recursively compile anything it calls), add the @torch.jit.export decorator to the method. To opt out of compilation use @torch.jit.ignore or @torch.jit.unused.

In [29]:
# import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

tensor([[ 9.7463, 10.6978],
        [ 9.3826, 11.3236]])
--Return--
None
> [0;32m/tmp/ipykernel_754881/3153929141.py[0m(17)[0;36mpython_only_fn[0;34m()[0m
[0;32m     15 [0;31m        [0;31m# Python APIs can be used[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 17 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m[0;34m[0m[0m
[0m[0;32m     19 [0;31m    [0;32mdef[0m [0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0minput[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
tensor([[ 114.5182,   17.4755],
        [-275.7512, -146.5401]])


# Example ( Annotating forward of nn.Module using example_inputs):

the following examples do not work ? 

In [30]:
#import torch
#import torch.nn as nn
from typing import NamedTuple, List

class MyModule(NamedTuple):
    result: List[int]

class TestNNModule(torch.nn.Module):
    def forward(self, a) -> MyModule:
        result = MyModule(result=a).result
        return result

pdt_model = TestNNModule()

# Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

# Run the scripted_model with actual inputs
print(scripted_model([20]))

RuntimeError: Can't redefine NamedTuple: __torch__.MyModule