#1. Import modules

In [None]:
from torchvision import transforms
from torchvision import datasets
import numpy as np
import PIL
import torch
import torch.nn as nn
import torchvision.datasets as dset

import torch.utils.data
import torchvision


#2. Choose the torch device

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


#2.1. Choose architecture and model checkpoint

In [None]:
import resnet_family
import densenet_family
import inception_family

def load_model(arch_name, num_classes, filename = None, pretrained=True):
    """Build a CNN classifier by architecture name and optionally load weights.

    Args:
        arch_name: Architecture identifier. Supported names: "resnet18",
            "resnet34", "resnet50", "se_resnet18", "se_resnet34",
            "se_resnet50", "resnext", "densenet121", "se_densenet121",
            "inception_v3", "se_inception_v3", "xception".
        num_classes: Number of output classes, forwarded to the model factory.
        filename: Optional path to a checkpoint containing a ``state_dict``.
            It is read with ``torch.load`` using ``map_location`` built from
            the notebook-global ``device`` and applied via ``load_state_dict``.
        pretrained: Forwarded unchanged to the model factory.

    Returns:
        Tuple ``(cnn_model, frame_size, header)``:
        the model; the expected input size as ``(height, width)`` --
        (299, 299) for the Inception/Xception family, (224, 224) otherwise;
        and the class column names ``"1" .. str(num_classes)`` (at least
        ``["1", "2"]`` when fewer than two classes are requested).

    Raises:
        ValueError: If ``arch_name`` is not one of the supported names.
    """
    # Class column names "1".."N". For num_classes < 2 keep the historical
    # two-column fallback ["1", "2"].
    header = [str(i) for i in range(1, max(num_classes, 2) + 1)]

    if arch_name == "resnet18":
        cnn_model = resnet_family.resnet18_model(num_classes, pretrained)
    elif arch_name == "resnet34":
        cnn_model = resnet_family.resnet34_model(num_classes, pretrained)
    elif arch_name == "resnet50":
        cnn_model = resnet_family.resnet50_model(num_classes, pretrained)
    elif arch_name == "se_resnet18":
        cnn_model = resnet_family.se_resnet18_model(num_classes, pretrained)
    elif arch_name == "se_resnet34":
        cnn_model = resnet_family.se_resnet34_model(num_classes, pretrained)
    elif arch_name == "se_resnet50":
        cnn_model = resnet_family.se_resnet50_model(num_classes, pretrained)
    elif arch_name == "resnext":
        cnn_model = resnet_family.resnext50_model(num_classes, pretrained)
    elif arch_name == "densenet121":
        cnn_model = densenet_family.densenet121_model(num_classes, pretrained)
    elif arch_name == "se_densenet121":
        cnn_model = densenet_family.se_densenet121_model(num_classes, pretrained)
    elif arch_name == "inception_v3":
        cnn_model = inception_family.inception_v3(num_classes, pretrained)
    elif arch_name == "se_inception_v3":
        cnn_model = inception_family.se_inception_v3(num_classes, pretrained)
    elif arch_name == "xception":
        cnn_model = inception_family.xception_model(num_classes, pretrained)
    else:
        # Fail fast with a clear message instead of an UnboundLocalError on
        # `cnn_model` / `frame_size` at the return statement.
        raise ValueError("Unsupported architecture: {!r}".format(arch_name))

    # Same size mapping as before: the Inception/Xception family ("ception")
    # takes 299x299 frames, ResNet/DenseNet variants take 224x224.
    frame_size = (299, 299) if "ception" in arch_name else (224, 224)

    if filename is not None:
        state_dict = torch.load(filename, map_location=torch.device(device))
        cnn_model.load_state_dict(state_dict)

    return cnn_model, frame_size, header


#3. Prepare the inference wrapper (converter)

In [None]:
class CnnInfer(torch.nn.Module):
    '''
    Inference wrapper that accepts a single RGB uint8 image tensor.

    The image is expected in HWC layout (height, width, channels), matching
    the sample built in the conversion cell. ``forward`` converts it to a CHW
    float tensor in [0, 1], normalizes each channel as ``(x - mean) / std``,
    adds a batch dimension, runs the wrapped CNN, and removes the batch
    dimension from the output.
    '''
    def __init__(self, cnn_model, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        '''
        Args:
            cnn_model: network to wrap; it is switched to eval mode here.
            mean: per-channel value subtracted after scaling to [0, 1].
            std: per-channel divisor; every entry must be non-zero.

        Raises:
            ValueError: if ``mean`` and ``std`` differ in length, or ``std``
                contains a zero.
        '''
        super().__init__()

        self.mean = list(mean)
        self.std = list(std)
        if len(self.mean) != len(self.std):
            raise ValueError("mean and std must have the same length")
        if any(s == 0 for s in self.std):
            raise ValueError("std must not contain zeros")

        # Store the constants as (C, 1, 1) buffers: they broadcast over H x W,
        # move together with the module on .to(device), and are kept as module
        # state by torch.jit.trace instead of being baked in as CPU constants.
        self.register_buffer(
            "_mean_t", torch.tensor(self.mean, dtype=torch.float).view(-1, 1, 1))
        self.register_buffer(
            "_std_t", torch.tensor(self.std, dtype=torch.float).view(-1, 1, 1))

        self.cnn = cnn_model
        self.cnn.eval()

    def forward(self, img):
        # HWC uint8 -> CHW float scaled to [0, 1]
        x = img.permute(2, 0, 1).to(torch.float) / 255
        # Per-channel normalization, then add the batch dimension the CNN expects.
        x = ((x - self._mean_t) / self._std_t).unsqueeze(0)
        y = self.cnn(x)
        return y.squeeze(0)  # remove batch dimension


#4. Convert the loaded model to a TorchScript module for C++

In [None]:
import json

# Available architectures:
#   resnet18, resnet34, resnet50, se_resnet18, se_resnet34, se_resnet50, resnext
#   densenet121, se_densenet121
#   inception_v3, se_inception_v3, xception
ARCH_NAME = "resnet34"
NUM_CLASSES = 5
CHECKPOINT = "model.ckpt-5"

cnn_model, frame_size, classes_header = load_model(ARCH_NAME, NUM_CLASSES, CHECKPOINT)

cnn_infer = CnnInfer(cnn_model)

# Trace the wrapper into a TorchScript module using a random HWC uint8 frame.
# torch.randint's `high` is exclusive, so 256 covers the full 0..255 range.
sample_frame_size = (frame_size[0], frame_size[1], 3)
sample = torch.randint(low=0, high=256, size=sample_frame_size, dtype=torch.uint8)
scripted_model = torch.jit.trace(cnn_infer, sample)
scripted_model.save('scripted_model.pth')

# Metadata for the C++ side (class header, input image size, etc.)
metadata = {
    "classes_header": ",".join(classes_header),
    "input_size": {
        "width": frame_size[1],
        "height": frame_size[0],
        "depth": 3
    }}

with open('metadata.json', 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=4)

print("please, check metadata:")
print(json.dumps(metadata, indent=4))


In [None]:
import os
import zipfile

output_filename = "script.zip"

# Bundle the TorchScript model and its metadata into a single archive.
# No compression argument is given, so zipfile's default (ZIP_STORED) applies.
with zipfile.ZipFile(output_filename, 'w') as myzip:
    for artifact_path in ('scripted_model.pth', 'metadata.json'):
        myzip.write(artifact_path)
    