In [1]:
!pip install torch==1.12.1
!pip install torchvision==0.13.1
!pip install Pillow==9.3.0

[0m

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from PIL import Image

In [3]:
class ConvNet(nn.Module):
    """Two-conv-layer classifier that maps 1x28x28 inputs to 10 log-probabilities.

    Submodule names (cn1, cn2, dp1, dp2, fc1, fc2) must stay stable: they are the
    keys used by the saved state_dict loaded into this model.
    """

    def __init__(self):
        super().__init__()
        self.cn1 = nn.Conv2d(1, 16, 3, 1)
        self.cn2 = nn.Conv2d(16, 32, 3, 1)
        # Channel-wise dropout on the pooled 4-D feature maps (N, 32, 12, 12).
        self.dp1 = nn.Dropout2d(0.10)
        # Element-wise dropout on the 2-D (N, 64) hidden vector. Dropout2d is meant
        # for (N, C, H, W) / (C, H, W) inputs; PyTorch deprecates feeding it 2-D
        # tensors and recommends plain Dropout, which drops elements the same way.
        # Dropout has no parameters, so the state_dict keys are unaffected.
        self.dp2 = nn.Dropout(0.25)
        # 4608 = 32 channels * 12 * 12: two unpadded 3x3 convs shrink 28 -> 26 -> 24,
        # then 2x2 max-pooling halves that to 12.
        self.fc1 = nn.Linear(4608, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10).

        Assumes x is (N, 1, 28, 28); other spatial sizes would not produce the
        4608 features fc1 expects.
        """
        x = F.relu(self.cn1(x))
        x = F.relu(self.cn2(x))
        x = F.max_pool2d(x, 2)
        x = self.dp1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dp2(x)
        x = self.fc2(x)
        op = F.log_softmax(x, dim=1)
        return op


model = ConvNet()

In [4]:
# Load the saved ConvNet weights onto the CPU and copy them into `model`.
# torch.load unpickles the file, so only point this at a trusted checkpoint.
PATH_TO_MODEL = "./convnet.pth"
convnet_state = torch.load(PATH_TO_MODEL, map_location="cpu")
model.load_state_dict(convnet_state)

<All keys matched successfully>

In [5]:
# Inference only: switch dropout layers to eval behavior and freeze every parameter.
model.eval()
model.requires_grad_(False);

In [6]:
# Compile the eager module into a TorchScript module that can be serialized and reloaded.
scripted_model = torch.jit.script(obj=model)

In [7]:
# Display the low-level TorchScript IR graph of the compiled forward (verbose output).
scripted_model.graph

graph(%self : __torch__.ConvNet,
      %x.1 : Tensor):
  %51 : Function = prim::Constant[name="log_softmax"]()
  %49 : int = prim::Constant[value=3]()
  %33 : int = prim::Constant[value=-1]()
  %26 : Function = prim::Constant[name="_max_pool2d"]()
  %20 : int = prim::Constant[value=0]()
  %19 : NoneType = prim::Constant()
  %7 : Function = prim::Constant[name="relu"]()
  %6 : bool = prim::Constant[value=0]()
  %17 : int = prim::Constant[value=2]() # /var/folders/gs/mjlw0j210yz02z4yrv9gshdm0000gq/T/ipykernel_13610/2721400238.py:16:28
  %32 : int = prim::Constant[value=1]() # /var/folders/gs/mjlw0j210yz02z4yrv9gshdm0000gq/T/ipykernel_13610/2721400238.py:18:29
  %cn1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="cn1"](%self)
  %x.5 : Tensor = prim::CallMethod[name="forward"](%cn1, %x.1) # /var/folders/gs/mjlw0j210yz02z4yrv9gshdm0000gq/T/ipykernel_13610/2721400238.py:12:12
  %x.9 : Tensor = prim::CallFunction(%7, %x.5, %6) # /var/folders/gs/mjlw0j210yz02z4yrv9gshdm0000gq/T

In [8]:
# TorchScript's Python-like rendering of the compiled forward(); print() keeps its line breaks.
scripted_source = scripted_model.code
print(scripted_source)

def forward(self,
    x: Tensor) -> Tensor:
  _0 = __torch__.torch.nn.functional._max_pool2d
  _1 = __torch__.torch.nn.functional.log_softmax
  cn1 = self.cn1
  x0 = (cn1).forward(x, )
  x1 = __torch__.torch.nn.functional.relu(x0, False, )
  cn2 = self.cn2
  x2 = (cn2).forward(x1, )
  x3 = __torch__.torch.nn.functional.relu(x2, False, )
  x4 = _0(x3, [2, 2], None, [0, 0], [1, 1], False, False, )
  dp1 = self.dp1
  x5 = (dp1).forward(x4, )
  x6 = torch.flatten(x5, 1)
  fc1 = self.fc1
  x7 = (fc1).forward(x6, )
  x8 = __torch__.torch.nn.functional.relu(x7, False, )
  dp2 = self.dp2
  x9 = (dp2).forward(x8, )
  fc2 = self.fc2
  x10 = (fc2).forward(x9, )
  return _1(x10, 1, 3, None, )



In [9]:
# Serialize the compiled code and its weights into one archive that torch.jit.load can read.
scripted_model.save('scripted_convnet.pt')

In [10]:
# Reload the TorchScript archive. map_location="cpu" mirrors how the eager weights were
# loaded, so the reload works regardless of which device the tensors were saved from.
loaded_scripted_model = torch.jit.load('scripted_convnet.pt', map_location="cpu")

In [11]:
DIGIT_IMAGE_PATH = "./digit_image.jpg"
# Image.open is lazy: pixel data is only read when the image is first processed.
image = Image.open(DIGIT_IMAGE_PATH)

In [12]:
def image_to_tensor(image, size=(28, 28), mean=(0.1302,), std=(0.3069,)):
    """Convert a PIL image into a normalized single-channel tensor for ConvNet.

    Args:
        image: PIL image; it is converted to one grayscale channel first.
        size: (height, width) to resize to. The default matches the 28x28 input
            ConvNet's fc1 layer was sized for.
        mean: per-channel mean subtracted during normalization.
        std: per-channel standard deviation divided out during normalization.

    Returns:
        Float tensor of shape (1, *size), scaled to [0, 1] by to_tensor and then
        normalized as (x - mean) / std. No batch dimension is added.
    """
    gray_image = transforms.functional.to_grayscale(image)
    resized_image = transforms.functional.resize(gray_image, size)
    input_image_tensor = transforms.functional.to_tensor(resized_image)
    # NOTE(review): the default mean/std presumably are the statistics used when the
    # loaded weights were trained — confirm before changing them.
    return transforms.functional.normalize(input_image_tensor, mean, std)

In [13]:
# Preprocess the sample image into a normalized (1, 28, 28) tensor; no batch axis yet.
input_tensor = image_to_tensor(image=image)

In [14]:
# Prepend a batch axis, (1, 28, 28) -> (1, 1, 28, 28), and run the reloaded TorchScript model.
batched_input = input_tensor[None]
loaded_scripted_model(batched_input)



tensor([[-1.0458e+01, -1.3929e+01, -2.5733e-03, -8.8133e+00, -1.0267e+01,
         -1.5833e+01, -1.2593e+01, -1.3940e+01, -6.0533e+00, -1.2960e+01]])

In [15]:
# Run the original eager model on the same batched input, and check explicitly (rather
# than by eye) that the TorchScript save/load round-trip reproduces its log-probabilities.
eager_output = model(input_tensor.unsqueeze(0))
torch.testing.assert_close(loaded_scripted_model(input_tensor.unsqueeze(0)), eager_output)
eager_output



tensor([[-1.0458e+01, -1.3929e+01, -2.5733e-03, -8.8133e+00, -1.0267e+01,
         -1.5833e+01, -1.2593e+01, -1.3940e+01, -6.0533e+00, -1.2960e+01]])