In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from PIL import Image

In [2]:
class ConvNet(nn.Module):
    """Small two-convolution CNN that outputs per-class log-probabilities.

    Layer sizes assume single-channel 28x28 inputs. Two unpadded 3x3
    convolutions shrink 28 -> 26 -> 24, and a 2x2 max-pool brings that to 12.
    So the flattened feature vector has 12 * 12 * 32 = 4608 entries, the
    ``fc1`` input size. Inputs of other spatial sizes flatten to a different
    length and will generally fail at ``fc1``.
    """

    def __init__(self):
        super().__init__()
        self.cn1 = nn.Conv2d(1, 16, 3, 1)
        self.cn2 = nn.Conv2d(16, 32, 3, 1)
        # Channel-wise dropout on the 4-D (N, C, H, W) pooled feature maps.
        self.dp1 = nn.Dropout2d(0.10)
        # Element-wise dropout on the 2-D (N, 64) hidden activations.
        # nn.Dropout2d on 2-D input is deprecated: PyTorch warns on every call,
        # even in eval mode, and says to use dropout instead, which keeps the
        # same behavior. Dropout modules have no parameters, so previously
        # saved state_dicts still load unchanged.
        self.dp2 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(12 * 12 * 32, 64)  # 4608 flattened features
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) log-probabilities."""
        x = F.relu(self.cn1(x))
        x = F.relu(self.cn2(x))
        x = F.max_pool2d(x, 2)
        x = self.dp1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dp2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = ConvNet()

In [3]:
PATH_TO_MODEL = "./convnet.pth"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
# strict=True (the default) makes load_state_dict raise on any key mismatch.
model.load_state_dict(
    torch.load(PATH_TO_MODEL, map_location="cpu"),
    strict=True,
)

<All keys matched successfully>

In [4]:
# Put the model in inference mode (dropout layers become no-ops) and stop
# tracking gradients for every parameter. Module.requires_grad_ applies
# requires_grad_(False) to each parameter, exactly like a manual loop.
model.eval()
model.requires_grad_(False);

In [5]:
# Compile the model's forward() into a TorchScript module that can be
# inspected, saved and reloaded in the following cells.
scripted_model = torch.jit.script(model)

In [6]:
# Show the TorchScript IR graph produced for forward().
scripted_model.graph

graph(%self : __torch__.ConvNet,
      %x.1 : Tensor):
  %51 : Function = prim::Constant[name="log_softmax"]()
  %49 : int = prim::Constant[value=3]()
  %40 : Function = prim::Constant[name="relu"]()
  %33 : int = prim::Constant[value=-1]()
  %26 : Function = prim::Constant[name="_max_pool2d"]()
  %20 : int = prim::Constant[value=0]()
  %19 : None = prim::Constant()
  %14 : Function = prim::Constant[name="relu"]()
  %7 : Function = prim::Constant[name="relu"]()
  %6 : bool = prim::Constant[value=0]()
  %17 : int = prim::Constant[value=2]() # <ipython-input-2-936a1c5cab85>:16:28
  %32 : int = prim::Constant[value=1]() # <ipython-input-2-936a1c5cab85>:18:29
  %2 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="cn1"](%self)
  %x.3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # <ipython-input-2-936a1c5cab85>:12:12
  %x.5 : Tensor = prim::CallFunction(%7, %x.3, %6) # <ipython-input-2-936a1c5cab85>:13:12
  %9 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d =

In [7]:
# `.code` is a str; print() renders its newlines instead of showing the
# escaped repr a bare expression would display.
print(scripted_model.code)

def forward(self,
    x: Tensor) -> Tensor:
  _0 = __torch__.torch.nn.functional.___torch_mangle_12.relu
  _1 = __torch__.torch.nn.functional._max_pool2d
  _2 = __torch__.torch.nn.functional.___torch_mangle_13.relu
  _3 = __torch__.torch.nn.functional.log_softmax
  x0 = (self.cn1).forward(x, )
  x1 = __torch__.torch.nn.functional.relu(x0, False, )
  x2 = (self.cn2).forward(x1, )
  x3 = _0(x2, False, )
  x4 = _1(x3, [2, 2], None, [0, 0], [1, 1], False, False, )
  x5 = (self.dp1).forward(x4, )
  x6 = torch.flatten(x5, 1, -1)
  x7 = (self.fc1).forward(x6, )
  x8 = _2(x7, False, )
  x9 = (self.dp2).forward(x8, )
  x10 = (self.fc2).forward(x9, )
  return _3(x10, 1, 3, None, )



In [8]:
# Serialize the TorchScript module to disk.
torch.jit.save(scripted_model, 'scripted_convnet.pt')

In [9]:
# Reload the module just saved; its outputs are compared with the eager
# model further down.
loaded_scripted_model = torch.jit.load('scripted_convnet.pt')

In [10]:
# Open the sample digit photo to classify.
image = Image.open("./digit_image.jpg")

In [11]:
def image_to_tensor(image, size=(28, 28), mean=(0.1302,), std=(0.3069,)):
    """Convert a PIL image into the normalized single-channel tensor ConvNet expects.

    Steps:
        1. Convert the image to grayscale.
        2. Resize it to exactly ``size``; aspect ratio is not preserved.
        3. Turn it into a float tensor with ``to_tensor``.
        4. Normalize each channel with ``mean`` / ``std``.

    Args:
        image: PIL image to preprocess.
        size: target (height, width); the default matches the 28x28 input
            ConvNet's fc1 size is built for.
        mean: per-channel mean used for normalization.
        std: per-channel standard deviation used for normalization.

    Returns:
        A float tensor of shape (1, height, width) with no batch dimension.
        Add one (e.g. ``.unsqueeze(0)``) before calling the model.
    """
    # NOTE(review): mean/std presumably must match the statistics used when
    # the model was trained (they resemble MNIST-style values) — confirm.
    gray_image = transforms.functional.to_grayscale(image)
    resized_image = transforms.functional.resize(gray_image, size)
    image_tensor = transforms.functional.to_tensor(resized_image)
    return transforms.functional.normalize(image_tensor, mean, std)

In [12]:
# Preprocess the photo into a normalized (1, 28, 28) tensor (no batch dim yet).
input_tensor = image_to_tensor(image)

In [13]:
# Add a batch dimension and run the reloaded TorchScript model; the result
# holds one log-probability per digit class.
loaded_scripted_model(input_tensor.unsqueeze(0))

tensor([[-9.3505e+00, -1.2089e+01, -2.2391e-03, -8.9248e+00, -9.8197e+00,
         -1.3350e+01, -9.0460e+00, -1.4492e+01, -6.3023e+00, -1.2283e+01]])

In [14]:
# Run the original eager model on the same batch and assert that it agrees
# with the reloaded TorchScript model, instead of comparing the printed
# tensors by eye. The cell still displays the eager output.
input_batch = input_tensor.unsqueeze(0)
eager_log_probs = model(input_batch)
assert torch.allclose(eager_log_probs, loaded_scripted_model(input_batch)), \
    "TorchScript output diverges from the eager model"
eager_log_probs

tensor([[-9.3505e+00, -1.2089e+01, -2.2391e-03, -8.9248e+00, -9.8197e+00,
         -1.3350e+01, -9.0460e+00, -1.4492e+01, -6.3023e+00, -1.2283e+01]])