## 1. Libraries and preprocessing

In [1]:
import torch
import torchvision
import onnx
import onnxruntime as ort
from torchvision import transforms as T
from PIL import Image
import os
from tqdm import tqdm
# PARSeq input pipeline: resize to 32x128 (H x W) with bicubic interpolation,
# convert the PIL image to a float tensor in [0, 1], then normalize every
# channel with mean=0.5 / std=0.5, which maps values into [-1, 1].
preprocess_parseq = T.Compose(
    [
        T.Resize(size=(32, 128), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=0.5, std=0.5),
    ]
)

In [2]:
def main(batch_size, device_type = "cpu", img_folder = "/home/ubuntu/parseq/demo_images",
         test_img_path="demo_images/art-01107.jpg"):
    """Script pretrained PARSeq to TorchScript and sanity-check it on one image.

    Loads PARSeq from torch.hub and greedily decodes ``test_img_path`` with the
    eager model. It then scripts the model to
    ``parseq_<batch_size>_torchscript.pt``, reloads that file and decodes the
    same image again, so the two printed labels can be compared.

    Args:
        batch_size: Batch dimension of the dummy example input; also used in
            the output file name.
        device_type: ``"cpu"`` or ``"cuda"`` (``"cuda"`` uses ``cuda:0``).
        img_folder: Unused by this function; kept for signature compatibility.
        test_img_path: Image decoded before and after the export.

    Raises:
        ValueError: If ``device_type`` is neither ``"cpu"`` nor ``"cuda"``.
    """
    devices = {"cpu": "cpu", "cuda": "cuda:0"}
    if device_type not in devices:
        # Previously an unknown value left `device` unbound and failed later.
        raise ValueError("device_type must be 'cpu' or 'cuda', got {!r}".format(device_type))
    device = torch.device(devices[device_type])
    parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)

    # Preprocess. Model expects a batch of images with shape: (B, C, H, W)
    img = preprocess_parseq(Image.open(test_img_path).convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = parseq(img)  # e.g. torch.Size([1, 26, 95]): 94 characters + [EOS] symbol

    # Greedy decoding
    label, confidence = parseq.tokenizer.decode(logits.softmax(-1))
    print(confidence)
    print('Decoded label = {}'.format(label[0]))

    # Keep the example input on the same device as the model.
    dummy_input = torch.rand(batch_size, 3, 32, 128, device=device)
    # Fill in the placeholder; without .format() every run overwrote the
    # literal file "parseq_{}_torchscript.pt".
    output_path = "parseq_{}_torchscript.pt".format(batch_size)

    # Export to TorchScript (this is not an ONNX export).
    # NOTE(review): Lightning documents example_inputs as used for
    # method="trace"; with method="script" it is presumably ignored — confirm.
    parseq.to_torchscript(file_path=output_path, method="script", example_inputs=dummy_input)

    ts_model = torch.jit.load(output_path, map_location=device)
    with torch.no_grad():
        logits = ts_model(img)
    print(logits.shape)

    # Greedy decoding with the scripted model's logits.
    label, confidence = parseq.tokenizer.decode(logits.softmax(-1))
    print('Decoded label = {}'.format(label[0]))

In [13]:
def main_batch(batch_size=1, device_type="cpu", img_folder="/home/ubuntu/parseq/digits_demo", img_type="jpg",
               test_img_path="demo_images/art-01107.jpg"):
    """Script PARSeq to TorchScript and check it agrees with the eager model.

    Loads PARSeq from torch.hub and decodes ``test_img_path`` as a sanity
    check. It then scripts the model to ``parseq_<batch_size>_torchscript.pt``
    and reloads it. Every image in ``img_folder`` whose extension matches
    ``img_type`` is decoded with both the eager and the scripted model, and
    the number of identical labels is counted.

    Args:
        batch_size: Batch dimension of the dummy example input; also used in
            the output file name.
        device_type: ``"cpu"`` or ``"cuda"`` (``"cuda"`` uses ``cuda:0``).
        img_folder: Directory of images to compare on.
        img_type: File extension to select, with or without the leading dot
            (compared case-insensitively). An empty string selects every file.
        test_img_path: Image used for the initial sanity check.

    Returns:
        Tuple ``(matched, total)``: the number of images whose eager and
        TorchScript labels agree, and the number of images evaluated.

    Raises:
        ValueError: If ``device_type`` is neither ``"cpu"`` nor ``"cuda"``.
    """
    devices = {"cpu": "cpu", "cuda": "cuda:0"}
    if device_type not in devices:
        # Previously an unknown value left `device` unbound and failed later.
        raise ValueError("device_type must be 'cpu' or 'cuda', got {!r}".format(device_type))
    device = torch.device(devices[device_type])
    parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)

    def load_image(path):
        # Preprocess. Model expects a batch of images with shape: (B, C, H, W)
        return preprocess_parseq(Image.open(path).convert('RGB')).unsqueeze(0).to(device)

    def decode(logits):
        # Greedy decoding -> (labels, per-character confidences).
        return parseq.tokenizer.decode(logits.softmax(-1))

    # Sanity check of the eager model on a single known image.
    img = load_image(test_img_path)
    with torch.no_grad():
        logits = parseq(img)  # e.g. torch.Size([1, 26, 95]): 94 characters + [EOS] symbol
    label, confidence = decode(logits)
    print(confidence)
    print('Decoded label = {}'.format(label[0]))

    # Export to TorchScript (this is not an ONNX export).
    # NOTE(review): Lightning documents example_inputs as used for
    # method="trace"; with method="script" it is presumably ignored — confirm.
    dummy_input = torch.rand(batch_size, 3, 32, 128, device=device)
    output_path = "parseq_{}_torchscript.pt".format(batch_size)
    parseq.to_torchscript(file_path=output_path, method="script", example_inputs=dummy_input)

    ts_model = torch.jit.load(output_path, map_location=device)
    with torch.no_grad():
        logits = ts_model(img)
    print(logits.shape)
    label, _ = decode(logits)
    print('Decoded label = {}'.format(label[0]))

    # Select files by their real extension (case-insensitive) rather than a
    # raw string suffix, so "a.JPG" is included and "notajpg" is not.
    wanted_ext = "." + img_type.lower().lstrip(".") if img_type else ""

    def is_wanted(name):
        return not wanted_ext or os.path.splitext(name)[1].lower() == wanted_ext

    matched = 0
    total = 0
    if not os.path.isdir(img_folder):
        # Keep the best-effort behaviour, but say why nothing was evaluated.
        print("Image folder not found: {}".format(img_folder))
    else:
        # sorted() makes the iteration order (and printed mismatches) reproducible.
        for filename in tqdm(sorted(os.listdir(img_folder))):
            if not is_wanted(filename):
                continue
            total += 1
            img_input = load_image(os.path.join(img_folder, filename))
            with torch.no_grad():
                label_ts, _ = decode(ts_model(img_input))
                label, _ = decode(parseq(img_input))

            # tqdm.write keeps messages from breaking the progress bar.
            if label[0] == label_ts[0]:
                tqdm.write("matched")
                matched += 1
            else:
                # Show both readings so the mismatch can be inspected.
                tqdm.write(label[0])
                tqdm.write(label_ts[0])
    print(matched, total)
    return matched, total

In [14]:
# Compare eager vs TorchScript PARSeq with the default settings (batch 1, CPU).
main_batch(batch_size=1, device_type="cpu")

Using cache found in /home/ubuntu/.cache/torch/hub/baudm_parseq_main


[tensor([0.9979, 0.9996, 0.9999, 0.9998, 0.9868, 0.9998, 0.9432, 0.9765, 0.9937,
        0.9981], grad_fn=<SliceBackward0>)]
Decoded label = CHEWBACCA


  " but it is a non-constant {}. Consider removing it.".format(name, hint))


torch.Size([1, 26, 95])
Decoded label = CHEWBACCA


  1%|▏         | 1/75 [00:01<01:19,  1.08s/it]

matched


  4%|▍         | 3/75 [00:01<00:28,  2.57it/s]

matched
matched


  7%|▋         | 5/75 [00:01<00:16,  4.24it/s]

matched
matched


  9%|▉         | 7/75 [00:02<00:12,  5.46it/s]

matched
matched


 12%|█▏        | 9/75 [00:02<00:11,  5.70it/s]

matched
matched


 15%|█▍        | 11/75 [00:02<00:10,  6.35it/s]

matched
matched


 17%|█▋        | 13/75 [00:02<00:09,  6.74it/s]

matched
matched


 20%|██        | 15/75 [00:03<00:08,  7.03it/s]

matched
matched


 23%|██▎       | 17/75 [00:03<00:08,  7.14it/s]

matched
matched


 25%|██▌       | 19/75 [00:03<00:07,  7.16it/s]

matched
matched


 29%|██▉       | 22/75 [00:04<00:06,  8.70it/s]

matched
matched


 32%|███▏      | 24/75 [00:04<00:06,  7.88it/s]

matched
matched


 35%|███▍      | 26/75 [00:04<00:06,  7.52it/s]

matched
matched


 37%|███▋      | 28/75 [00:04<00:06,  7.31it/s]

matched
matched


 40%|████      | 30/75 [00:05<00:08,  5.01it/s]

matched
matched


 43%|████▎     | 32/75 [00:05<00:07,  5.93it/s]

matched
matched


 45%|████▌     | 34/75 [00:06<00:06,  6.54it/s]

matched
matched


 48%|████▊     | 36/75 [00:06<00:05,  6.89it/s]

matched
matched


 51%|█████     | 38/75 [00:06<00:05,  7.06it/s]

matched
matched


 53%|█████▎    | 40/75 [00:06<00:04,  7.12it/s]

matched
matched


 56%|█████▌    | 42/75 [00:07<00:04,  7.16it/s]

matched
matched


 59%|█████▊    | 44/75 [00:07<00:04,  7.00it/s]

matched
matched


 61%|██████▏   | 46/75 [00:07<00:04,  7.12it/s]

matched
matched


 64%|██████▍   | 48/75 [00:07<00:03,  7.14it/s]

matched
matched


 68%|██████▊   | 51/75 [00:08<00:02,  9.21it/s]

matched
matched


 71%|███████   | 53/75 [00:08<00:02,  8.09it/s]

matched
matched


 73%|███████▎  | 55/75 [00:08<00:02,  7.66it/s]

matched
matched


 76%|███████▌  | 57/75 [00:09<00:02,  7.43it/s]

matched
matched


 79%|███████▊  | 59/75 [00:09<00:02,  7.32it/s]

matched
matched


 81%|████████▏ | 61/75 [00:09<00:01,  7.07it/s]

matched
matched


 84%|████████▍ | 63/75 [00:10<00:01,  6.34it/s]

matched
matched


 87%|████████▋ | 65/75 [00:10<00:01,  5.90it/s]

matched
matched


 89%|████████▉ | 67/75 [00:10<00:01,  5.76it/s]

matched
matched


 92%|█████████▏| 69/75 [00:11<00:01,  5.73it/s]

matched
matched


 95%|█████████▍| 71/75 [00:11<00:00,  5.68it/s]

matched
matched


 97%|█████████▋| 73/75 [00:11<00:00,  5.64it/s]

matched
matched


100%|██████████| 75/75 [00:12<00:00,  6.17it/s]

matched
matched
73 73





In [None]:
# NOTE(review): leftover scratch cell — it only prints an empty line.
print("")