# Import lib

In [1]:
import matplotlib.pyplot as plt
from PIL import Image

from vietocr.tool.config import Cfg

import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, softmax

from vietocr.tool.translate import build_model, process_input, translate_beam_search
from vietocr.model.vocab import Vocab

# Define class encoder and decoder for convert

In [2]:
class TextRecognitionEncoder(nn.Module):
    def __init__(self, config):
        super(TextRecognitionEncoder, self).__init__()

        self.device = config['device']
        self.model, _ = build_model(config)
        weights = config['weights']
        self.model.load_state_dict(torch.load(weights, map_location=torch.device(self.device)))

    def forward(self, img):
        """
        src: timestep x batch_size x channel
        hidden: batch_size x hid_dim
        encoder_outputs: src_len x batch_size x hid_dim
        """
        src = self.model.cnn(img)
        encoder_outputs, hidden = self.model.transformer.encoder(src)
        return hidden, encoder_outputs.transpose(0, 1)
    
class TextRecognitionDecoder(nn.Module):
    def __init__(self, config):
        super(TextRecognitionDecoder, self).__init__()

        self.device = config['device']
        self.model, _ = build_model(config)
        weights = config['weights']
        self.model.load_state_dict(torch.load(weights, map_location=torch.device(self.device)))

    def forward(self, tgt, hidden, encoder_outputs):
        """
        tgt: timestep x batch_size 
        hidden: batch_size x hid_dim
        encouder: src_len x batch_size x hid_dim
        output: batch_size x 1 x vocab_size
        """
        tgt = tgt.transpose(0, 1)
        tgt = tgt[-1]
        encoder_outputs = encoder_outputs.transpose(0, 1)
        output, hidden, _ = self.model.transformer.decoder(tgt, hidden, encoder_outputs)
        output = output.unsqueeze(1)
        
        return output, hidden

# Load model and config

In [3]:
config = Cfg.load_config_from_file('../config/vietocr_seq2seq_config.yaml')

config['weights'] = '../checkpoint/vgg_seq2seq.pth'
config['cnn']['pretrained']=False
config['predictor']['beamsearch'] = False
config['device'] = 'cuda'

model_encoder = TextRecognitionEncoder(config=config)
model_decoder = TextRecognitionDecoder(config=config)
model_encoder.eval()
model_decoder.eval()

vocab = Vocab(config['vocab'])

  self.model.load_state_dict(torch.load(weights, map_location=torch.device(self.device)))
  self.model.load_state_dict(torch.load(weights, map_location=torch.device(self.device)))


# Load image

In [4]:
img = Image.open("../asset/test.png")
img = process_input(img, config['dataset']['image_height'], 
                config['dataset']['image_min_width'], config['dataset']['image_max_width'])        
img = img.to(config['device'])

# Inference with torch

In [5]:
def translate_text(model_decoder, hidden, encoder_outputs, device, max_seq_length=128, sos_token=1, eos_token=2):

    with torch.no_grad():

        translated_sentence = [[sos_token]*len(img)]
        char_probs = [[1]*len(img)]

        max_length = 0

        while max_length <= max_seq_length and not all(np.any(np.asarray(translated_sentence).T==eos_token, axis=1)):

            tgt_inp = torch.LongTensor(translated_sentence).to(device)
            
            output, hidden = model_decoder(tgt_inp.transpose(0,1), hidden, encoder_outputs)
            output = softmax(output, dim=-1)
            output = output.to('cpu')

            values, indices  = torch.topk(output, 5)
            
            indices = indices[:, -1, 0]
            indices = indices.tolist()
            
            values = values[:, -1, 0]
            values = values.tolist()
            char_probs.append(values)

            translated_sentence.append(indices)   
            max_length += 1

            del output

        translated_sentence = np.asarray(translated_sentence).T
        
        char_probs = np.asarray(char_probs).T
        char_probs = np.multiply(char_probs, translated_sentence>3)
        char_probs = np.sum(char_probs, axis=-1)/(char_probs>0).sum(-1)
    
    return translated_sentence, char_probs

In [6]:
hidden, encoder_outputs = model_encoder(img)

s, prob = translate_text(model_decoder=model_decoder, hidden=hidden, encoder_outputs=encoder_outputs, device=config['device'])

s = s[0].tolist()
prob = prob[0].tolist()
text = vocab.decode(s)
print(text)

Mặt hàng bán (Hoặc ngành nghề kinh doanh)


# Convert Vietocr

In [7]:
print("Export to ONNX: Encoder of Vietocr")
onnx_path = "../checkpoint/text_recognition_encoder.onnx"
torch.onnx.export(
    model_encoder,                       
    img,                  
    onnx_path,                
    export_params=True,          
    opset_version=15,            
    do_constant_folding=True,     
    input_names=['input_image'],   
    output_names=['hidden', 'encoder_outputs'],  
    dynamic_axes={                 
        'input_image': {0: 'batch_size', 2: 'height', 3: 'width'},
        'hidden': {0: 'batch_size'},
        'encoder_outputs': {0: 'batch_size', 1: 'src_len'}
    }
)

Export to ONNX: Encoder of Vietocr




In [8]:
tgt_inp = torch.full((1, 1), 1, dtype=torch.long, device=config['device'])
print("Export to ONNX: Decoder of Vietocr")
onnx_path = "../checkpoint/text_recognition_decoder.onnx"
torch.onnx.export(
    model_decoder,                         
    (tgt_inp, hidden, encoder_outputs),                 
    onnx_path,                    
    export_params=True,            
    opset_version=15,              
    do_constant_folding=True,      
    input_names=['tgt_inp', 'hidden_input', 'encoder_outputs'],  
    output_names=['output', 'hidden_output'],  
    dynamic_axes={                 
        'tgt_inp': {0: 'batch_size', 1: 'time_step'},
        'hidden_input': {0: 'batch_size'},
        'encoder_outputs': {0: 'batch_size', 1: 'src_len'},
        'output': {0: 'batch_size'},
        'hidden_output': {0: 'batch_size'}
    }
)

Export to ONNX: Decoder of Vietocr


  assert (output == hidden).all()
