## Config

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

from vietocr.tool.config import Cfg
from vietocr.model.trainer import Trainer

In [None]:
# Start from the 'vgg_transformer' configuration and override it below.
config = Cfg.load_config_from_name('vgg_transformer')

# Values merged into config['dataset'].
dataset_params = dict(
    name='imgVietocr',
    data_root='/mnt/disk1/mbbank/OCR/DATA/data_quangnd/new_train',
    train_annotation='/mnt/disk1/mbbank/OCR/DATA/team/train.txt',
    valid_annotation='/mnt/disk1/mbbank/OCR/DATA/team/val.txt',
)

# Values merged into config['trainer'].
params = dict(
    print_every=200,
    valid_every=2 * 200,
    iters=2000000,
    checkpoint='/mnt/disk1/mbbank/OCR/CODE/VietOcr/weight/vietocr_V1.pth',
    export='/mnt/disk1/mbbank/OCR/CODE/VietOcr/weight/vietocr_V2.pth',
    metrics=150,
)

config['dataset'].update(dataset_params)
config['trainer'].update(params)

# This line is optional; extend the vocabulary here only when the data needs
# extra characters. The RECOGNIZE class used for inference in this notebook
# appends the same characters.
config['vocab'] = config['vocab'] + ('–' + 'ü' + 'ā' + 'ö')

config['optimizer']['max_lr'] = 0.00005
config['device'] = 'cuda:0'

# Bare expression: the notebook shows the final configuration as cell output.
config

## Train

In [None]:
# Build the trainer from the configuration assembled above.
# NOTE(review): pretrained=False presumably skips loading released vietocr
# weights — confirm against vietocr.model.trainer.Trainer.
trainer = Trainer(config, pretrained=False)
# Persist the trainer's configuration to a YAML file next to the code.
trainer.config.save('/mnt/disk1/mbbank/OCR/CODE/VietOcr/vietocr/config/config.yml')

In [None]:
# Optional sanity check before training (left disabled):
# trainer.visualize_dataset()

# Run the training loop using the settings stored in the trainer's config.
trainer.train()

In [None]:
# Evaluate the trained model; as the last expression in the cell, the returned
# value is shown as cell output.
# NOTE(review): presumably reports accuracy on the validation annotations —
# confirm against vietocr's Trainer.precision.
trainer.precision()

In [None]:
# Visual spot check of the trained model.
# NOTE(review): presumably plots sample images with their predicted text —
# confirm against vietocr's Trainer.visualize_prediction.
trainer.visualize_prediction()

In [None]:
# Release GPU memory cached by PyTorch once training is finished.
#
# torch is imported in this cell because the only other `import torch` in the
# notebook lives in the inference cell further down; without it, running the
# cells from top to bottom fails here with a NameError.
import torch

# No autograd work happens here, so no torch.no_grad() wrapper is needed.
torch.cuda.empty_cache()

## Infer

In [None]:
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg

import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch

class RECOGNIZE():
    """Wrap a vietocr ``Predictor`` configured from 'vgg_transformer'.

    The instance keeps the configuration it was built with in ``self.config``
    and the predictor in ``self.detector``.
    """

    # Characters appended to the vocabulary of the stock configuration. The
    # training cell of this notebook appends the same characters; keep the two
    # in sync (presumably the checkpoint depends on the vocabulary — confirm).
    EXTRA_VOCAB = '–' + 'ü' + 'ā' + 'ö'

    def __init__(self, weight_path='./weight/vietocr_v1.pth', device='cpu') -> None:
        """Build the predictor.

        Args:
            weight_path: Value stored in ``config['weights']``; the checkpoint
                the predictor should use.
            device: Value stored in ``config['device']`` (e.g. ``'cpu'`` or
                ``'cuda:0'``).
        """
        config = Cfg.load_config_from_name('vgg_transformer')
        config['weights'] = weight_path
        # NOTE(review): presumably prevents fetching pretrained backbone
        # weights, since the checkpoint above already contains them — confirm.
        config['cnn']['pretrained'] = False
        config['device'] = device
        config['predictor']['beamsearch'] = False
        config['vocab'] = config['vocab'] + self.EXTRA_VOCAB
        self.config = config
        self.detector = Predictor(config)

    def predict_image(self, img_path):
        """Run the predictor on a single image file.

        Args:
            img_path: Path of an image file that PIL can open.

        Returns:
            Whatever ``Predictor.predict`` returns for the RGB version of the
            image (presumably the recognized text — confirm against vietocr).
        """
        # Close the underlying file deterministically instead of relying on
        # garbage collection; convert() returns a new, fully loaded image, so
        # it stays usable after the file has been closed.
        with Image.open(img_path) as source:
            img = source.convert('RGB')
        return self.detector.predict(img)

In [None]:
# Build the recognizer from the V1 checkpoint on the third GPU.
model_v1 = RECOGNIZE(weight_path='/mnt/disk1/mbbank/OCR/CODE/VietOcr/weight/vietocr_V1.pth', device='cuda:2')

# Maps each file name in the test directory to its prediction.
v1_predict = {}
org_test = '/mnt/disk1/mbbank/OCR/DATA/data_quangnd/test'
# os.listdir() returns entries in arbitrary order; sorting makes the progress
# bar and the result file reproducible between runs.
bar = tqdm(sorted(os.listdir(org_test)))
for img_path in bar:
    v1_predict[img_path] = model_v1.predict_image(os.path.join(org_test, img_path))

# Predictions contain non-ASCII text, so the encoding is fixed explicitly
# instead of depending on the locale's default.
with open('/mnt/disk1/mbbank/OCR/CODE/VietOcr/result/vietocr_V1.txt', 'w', encoding='utf-8') as f:
    # One "<file name>\t<prediction>" line per image.
    f.writelines('%s\t%s\n' % (key, value) for key, value in v1_predict.items())

In [None]:
# Display one test image with its stored prediction as the plot title.
sample_name = 'public_test_img_10774.jpg'
img_path = '/'.join((org_test, sample_name))
image = Image.open(img_path).convert('RGB')
plt.imshow(image)
plt.title(v1_predict[sample_name])