## TorchScript 

### JIT Tracing 

In [1]:
import torch
import torchvision

In [4]:
# Build an untrained AlexNet and switch it to eval mode before tracing, so the
# Dropout layers in its classifier behave as in inference.
model = torchvision.models.alexnet(weights=None)
model.eval()
example_input = torch.rand(1, 3, 224, 224)
# traced_model is now a TorchScript GRAPH (recorded from the ops run on
# example_input), not an EAGER nn.Module
traced_model = torch.jit.trace(model, example_input)


In [5]:
# Serialize the traced TorchScript module (graph + weights) to a file that
# torch.jit.load can read back.
torch.jit.save(traced_model, "alexnet_traced.pt")

In [7]:
# Printing a traced module shows its submodule hierarchy, each entry tagged
# with the original Python class name (original_name=...).
module_summary = str(traced_model)
print(module_summary)

AlexNet(
  original_name=AlexNet
  (features): Sequential(
    original_name=Sequential
    (0): Conv2d(original_name=Conv2d)
    (1): ReLU(original_name=ReLU)
    (2): MaxPool2d(original_name=MaxPool2d)
    (3): Conv2d(original_name=Conv2d)
    (4): ReLU(original_name=ReLU)
    (5): MaxPool2d(original_name=MaxPool2d)
    (6): Conv2d(original_name=Conv2d)
    (7): ReLU(original_name=ReLU)
    (8): Conv2d(original_name=Conv2d)
    (9): ReLU(original_name=ReLU)
    (10): Conv2d(original_name=Conv2d)
    (11): ReLU(original_name=ReLU)
    (12): MaxPool2d(original_name=MaxPool2d)
  )
  (avgpool): AdaptiveAvgPool2d(original_name=AdaptiveAvgPool2d)
  (classifier): Sequential(
    original_name=Sequential
    (0): Dropout(original_name=Dropout)
    (1): Linear(original_name=Linear)
    (2): ReLU(original_name=ReLU)
    (3): Dropout(original_name=Dropout)
    (4): Linear(original_name=Linear)
    (5): ReLU(original_name=ReLU)
    (6): Linear(original_name=Linear)
  )
)


In [8]:
# .code is the TorchScript source that tracing generated for forward().
generated_forward = traced_model.code
print(generated_forward)

def forward(self,
    x: Tensor) -> Tensor:
  classifier = self.classifier
  avgpool = self.avgpool
  features = self.features
  _0 = (avgpool).forward((features).forward(x, ), )
  input = torch.flatten(_0, 1)
  return (classifier).forward(input, )



In [6]:
# Load the serialized TorchScript model and run one forward pass on random input.
loaded_model = torch.jit.load("alexnet_traced.pt")
# Pure inference: no_grad skips recording the autograd graph (saves memory/time).
with torch.no_grad():
    out = loaded_model(torch.rand(1, 3, 224, 224))
print(out.shape)

torch.Size([1, 1000])


### Scripting

In [4]:
import torch
import torch.nn as nn

# Scripted (not traced), so the data-dependent branch is compiled as real
# control flow: returns whichever input has the larger minimum element, and
# y when the minima are equal (strict `>`).
@torch.jit.script
def example(x, y):
    if x.min() > y.min():
        return x
    return y
    

In [16]:
# FeaturesCNNNet traces its two sub-networks separately inside __init__ and keeps
# them as submodules of a plain nn.Module (instead of subclassing
# torch.jit.ScriptModule with @torch.jit.script_method). The Python forward that
# glues them together (the flatten step) can still be compiled afterwards with
# torch.jit.script(model).
class FeaturesCNNNet(nn.Module):
    """AlexNet-style CNN built from two independently traced TorchScript stages.

    Args:
        num_classes: number of logits produced by the final Linear layer.

    Note:
        ``features`` is traced on a (1, 3, 224, 224) example and ``classifier``
        on a (1, 256*6*6) example. The hard-coded 256*6*6 is the flattened size
        of the feature map for 224x224 images; inputs whose feature map is not
        256x6x6 will not fit the classifier's first Linear layer.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Switch to eval mode once, then trace: the trace records the ops executed
        # for this example input.
        self.features = torch.jit.trace(features.eval(), torch.rand(1, 3, 224, 224))

        classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes),
        )
        # The classifier contains Dropout, so it must be in eval mode before
        # tracing; otherwise training-mode (random) dropout is recorded in the trace.
        self.classifier = torch.jit.trace(classifier.eval(), torch.rand(1, 256 * 6 * 6))

    def forward(self, x):
        """Run features -> flatten -> classifier and return the logits.

        Args:
            x: image batch; assumed (N, 3, 224, 224) to match the traced shapes.

        Returns:
            Tensor of shape (N, num_classes).
        """
        x = self.features(x)
        # Flatten all dims except the batch dim (the same step torchvision's
        # AlexNet uses); unlike .view, torch.flatten also handles non-contiguous input.
        x = torch.flatten(x, 1)
        return self.classifier(x)
            
model = FeaturesCNNNet()
x = torch.rand(1, 3, 224, 224)
# Inference only: no_grad avoids recording the autograd graph for this forward pass.
with torch.no_grad():
    out = model(x)
print("Output shape: ", out.shape)
    

Output shape:  torch.Size([1, 2])


In [18]:
# Compile the whole wrapper (in eval mode) with torch.jit.script, including
# its traced features/classifier submodules, then save it to disk.
scripted_model = torch.jit.script(model.eval())
scripted_model.save("full.pt")

In [19]:
# load model
# Reload the scripted model from full.pt.
loaded_model = torch.jit.load(f="full.pt")

### Weirdness: unannotated arguments default to `Tensor`

In [21]:
# Gotcha: TorchScript treats parameters without a type annotation as Tensor.
# Despite its name, this add_int is compiled as (Tensor, Tensor) -> Tensor,
# as the printed .code shows.
@torch.jit.script
def add_int(x, y):
    # `+` on two Tensors appears as torch.add in the compiled code
    return x + y
print(add_int.code)

def add_int(x: Tensor,
    y: Tensor) -> Tensor:
  return torch.add(x, y)



In [22]:
# Fix: add type annotations (type hints) to the signature, so TorchScript
# compiles the function as (int, int) -> int instead of the default Tensor types.
@torch.jit.script
def add_int(x:int, y:int) -> int:
    return x+y
print(add_int.code)

def add_int(x: int,
    y: int) -> int:
  return torch.add(x, y)

