# 1. Import modules

In [None]:
from torchvision import transforms
from torchvision import datasets
import numpy as np
import PIL
import torch
import torch.nn as nn
import torchvision.datasets as dset

import torch.utils.data
import torchvision


# 2. Choose the torch device

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
torch_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


# 3. Prepare the converter (inference wrapper around the CNN)

In [None]:
class CnnInfer(torch.nn.Module):
    """Wrap a CNN so it accepts one raw RGB image and returns one prediction.

    The preprocessing is part of the module so it is captured when the
    wrapper is traced/scripted. The steps are: HWC -> CHW permutation,
    uint8 -> float scaling into [0, 1], per-channel normalization
    ``(x - mean) / std``, and adding/removing the batch dimension around
    the CNN call.

    Args:
        cnn_model: the network to wrap. It is switched to eval mode.
        mean: per-channel means subtracted after scaling to [0, 1].
            Defaults to ``[0.5, 0.5, 0.5]``.
        std: per-channel standard deviations divided out after the mean
            subtraction. Defaults to ``[0.5, 0.5, 0.5]``.

    Raises:
        ValueError: if any ``std`` entry is zero, which would divide by zero.

    ``forward`` expects an RGB ``uint8`` image tensor laid out as
    (height, width, channels); ``permute(2, 0, 1)`` relies on that layout.
    """

    def __init__(self, cnn_model, mean=None, std=None):
        super().__init__()

        # Stored as plain lists so they become constants in a traced graph.
        self.mean = [0.5, 0.5, 0.5] if mean is None else list(mean)
        self.std = [0.5, 0.5, 0.5] if std is None else list(std)
        if any(s == 0 for s in self.std):
            raise ValueError("std must not contain zeros (division by zero)")
        self.cnn = cnn_model
        # eval() on the wrapper recurses into self.cnn and also clears the
        # wrapper's own ``training`` flag, so the two cannot disagree.
        self.eval()

    def forward(self, img):
        """Normalize a single HWC uint8 image, run the CNN, drop the batch dim."""
        x = img.permute(2, 0, 1).to(torch.float) / 255
        # Same arithmetic as torchvision.transforms.functional.normalize:
        # per-channel (x - mean) / std, broadcast over (C, H, W).
        mean = torch.as_tensor(self.mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
        x = ((x - mean) / std).unsqueeze(0)  # add batch dimension
        y = self.cnn(x)
        return y.squeeze(0)  # remove batch dimension


# 4. Convert the loaded model to TorchScript for C++

In [None]:
import json
import utils

# Available architectures:
# - densenet121, se_densenet121
# - resnet18, resnet34, resnet50, se_resnet18, se_resnet34, se_resnet50, resnext50, se_resnext50
# - inception_v3, se_inception_v3, xception, se_xception, inception_resnet

cnn_model, frame_size, classes_header = utils.load_model(
    "se_xception", 5, "models/se_xception.ckpt", device=torch_device
)

cnn_infer = CnnInfer(cnn_model)

# Build the TorchScript module by tracing the wrapper on a random HWC frame.
# frame_size[0] is used as the height and frame_size[1] as the width,
# matching the metadata written below.
sample_frame_size = (frame_size[0], frame_size[1], 3)
# The sample must be on the same device as the model weights. Otherwise
# tracing fails with a CPU/CUDA mismatch whenever the model was loaded onto
# 'cuda:0'. Taking the device from the parameters works wherever load_model
# actually put the model.
# NOTE(review): a model traced on CUDA will expect CUDA input on the C++ side.
model_device = next(cnn_model.parameters()).device
# torch.randint's ``high`` is exclusive: 256 covers the full uint8 range 0..255.
sample = torch.randint(
    low=0, high=256, size=sample_frame_size, dtype=torch.uint8, device=model_device
)
scripted_model = torch.jit.trace(cnn_infer, sample)
scripted_model.save('scripted_model.pth')

# Metadata (class headers, input image size, etc.)
metadata = {
    "classes_header": ",".join(classes_header),
    "input_size": {
        "width": frame_size[1],
        "height": frame_size[0],
        "depth": 3,
    },
}

with open('metadata.json', 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=4)

print("please, check metadata:")
print(json.dumps(metadata, indent=4))


In [None]:
import os
import zipfile

output_filename = "script.zip"

# Bundle the TorchScript model with its metadata into a single archive.
archive_members = ('scripted_model.pth', 'metadata.json')
with zipfile.ZipFile(output_filename, mode='w') as archive:
    for member in archive_members:
        archive.write(member)
    