References
1. https://hackernoon.com/how-to-run-machine-learning-models-in-the-browser-using-onnx
2. https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js/tree/master/full_demo

In [2]:
# !pip install onnx onnxruntime

In [1]:
import torch
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [2]:
class Net_v2(nn.Module):
    """Small CNN classifier: three conv -> batch-norm -> ReLU stages and a linear head.

    The linear head expects 512 flattened features. With the 28x28
    single-channel inputs used in this notebook the spatial size goes
    28 -> 26 (conv1) -> 13 (max-pool) -> 6 (conv2) -> 2 (conv3), which gives
    128 * 2 * 2 = 512.

    ``forward`` returns softmax probabilities over 10 classes, not raw logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict key prefixes, so they must stay
        # in sync with any saved checkpoint.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=0)
        self.bnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0)
        self.bnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0)
        self.bnorm3 = nn.BatchNorm2d(128)
        self.drop = nn.Dropout(p=0.5)
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        """Map a batch of images to per-class probabilities of shape (N, 10)."""
        # Stem: the only stage followed by max-pooling.
        stem = F.max_pool2d(F.relu(self.bnorm1(self.conv1(x))), 2)

        # The two strided convolutions do the remaining down-sampling.
        mid = F.relu(self.bnorm2(self.conv2(stem)))
        deep = F.relu(self.bnorm3(self.conv3(mid)))

        # Dropout is applied to the feature maps before they are flattened.
        features = torch.flatten(self.drop(deep), 1)

        return F.softmax(self.fc(features), dim=1)

# Build the network on the selected device and smoke-test it with a single
# random 28x28 single-channel image; the printed shape should be [1, 10].
model = Net_v2()
model = model.to(DEVICE)
inp = torch.rand(1, 1, 28, 28).to(DEVICE)
print(model(inp).shape)

torch.Size([1, 10])


In [3]:
# Load weights
import pickle
import io

# https://github.com/pytorch/pytorch/issues/16797#issuecomment-633423219
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that deserialises pickled torch storages onto the CPU.

    Only the lookup of ``torch.storage._load_from_bytes`` is intercepted; every
    other class/function lookup is delegated to ``pickle.Unpickler``.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, substituting a CPU-mapping storage loader."""
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        def load_on_cpu(b):
            # Same job as the stock loader, but with tensors remapped to the CPU.
            return torch.load(io.BytesIO(b), map_location='cpu')

        return load_on_cpu

savepath = "./mnist_v2.pkl"

# NOTE: unpickling can execute arbitrary code, so only load checkpoints that
# come from a trusted source.
with open(savepath, 'rb') as filehandler:
    # The file only needs to stay open while the checkpoint is deserialised.
    contents = CPU_Unpickler(filehandler).load()

# The checkpoint is a mapping holding the best weights and the matching loss.
model.load_state_dict(contents['best_weight'])
print("Minimum loss", contents['best_loss'])

# Drop the reference so the checkpoint does not linger in the kernel.
del contents

Minimum loss 1.5555101667404174


In [7]:
class ModelWrapper(nn.Module):
    """Wrap a classifier with the preprocessing needed for raw RGBA pixel input.

    The wrapper takes a tensor holding 280 * 280 * 4 values (a 280 x 280 image
    with red, green, blue and alpha channels, typically passed in flattened),
    keeps only the alpha channel, shrinks it to 28 x 28, rescales and
    normalises it, and then runs the wrapped model.
    """

    def __init__(self, model):
        super().__init__()

        # Normalisation applied after scaling; the constants are presumably
        # the MNIST mean/std -- confirm against the training pipeline.
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.preprocess = transforms.Compose([normalize])

        # The wrapped network is switched to evaluation mode.
        self.model = model.eval()

    def forward(self, x):
        """Preprocess the raw pixel values in ``x`` and return the model output."""
        # Restore the height x width x channel layout of the incoming data.
        rgba = x.reshape(280, 280, 4)

        # Channel index 3 is alpha; it is the only channel that is used.
        alpha = rgba.narrow(2, 3, 1)

        # Lay the alpha plane out as (batch, channel, height, width).
        image = alpha.reshape(1, 1, 280, 280)

        # 10x10 average pooling shrinks 280x280 to 28x28; dividing by 255
        # assumes the incoming values are in the 0-255 range.
        image = F.avg_pool2d(image, 10) / 255

        return self.model(self.preprocess(image))

In [8]:
# Attach the in-graph preprocessing so that it is exported along with the network.
wrapped_model = ModelWrapper(model=model)

In [9]:
# Dummy input used for this smoke test and later for tracing the ONNX export.
# Only the element count (280 * 280 * 4) matters to the wrapper, which reshapes
# its input itself. The tensor is created on DEVICE because the wrapped model's
# weights live there; a CPU tensor would raise a device-mismatch error whenever
# CUDA is in use.
inp = torch.rand(1, 4, 280, 280).to(DEVICE)
print(wrapped_model(inp).shape)

torch.Size([1, 10])


In [11]:
# link:
# https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

# Names given to the graph's input and output tensors in the exported file.
input_names = ["input"]
output_names = ["output"]

# Trace the wrapped model with the dummy input and write the ONNX file.
# No dynamic axes are declared, so the exported input shape is fixed to the
# shape of `inp`.
torch.onnx.export(
    wrapped_model.eval(),        # model to export, in evaluation mode
    inp,                         # example input that drives the trace
    "../web/src/model.onnx",     # destination file
    input_names=input_names,
    output_names=output_names,
    export_params=True,          # embed the trained weights in the file
    do_constant_folding=True,    # pre-compute constant sub-graphs
    verbose=True,
)

verbose: False, log level: Level.ERROR

