In [None]:
import torch
import glob as glob
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
from dataclasses import dataclass
from transformers import (VisionEncoderDecoderModel,TrOCRProcessor)

# Default size (in inches) for figures created without an explicit figsize.
plt.rcParams.update({'figure.figsize': (12, 9)})

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable training hyperparameters (batch size, epoch count, learning rate).

    Values are readable as class attributes, e.g. ``TrainingConfig.BATCH_SIZE``.
    """

    BATCH_SIZE: int = 48
    EPOCHS: int = 35
    LEARNING_RATE: float = 5e-5  # same value as 0.00005

@dataclass(frozen=True)
class DatasetConfig:
    """Immutable dataset location settings.

    ``DATA_ROOT`` is a relative path; presumably it is resolved against the
    working directory by whoever uses it — TODO confirm.
    """

    DATA_ROOT: str = 'licence_plate'

@dataclass(frozen=True)
class ModelConfig:
    """Immutable model settings: the Hugging Face identifier of the base TrOCR model."""

    MODEL_NAME: str = 'microsoft/trocr-small-printed'

In [None]:
# Directory of the fine-tuned checkpoint to evaluate.
# NOTE(review): 'checkpoint-xxxx' looks like a placeholder — replace it with the
# real checkpoint directory, otherwise the local-only load below will fail.
model_full_path = '/home/foziljon/npr/seq2seq_model_printed/checkpoint-xxxx'

# The processor (image preprocessing + token decoding) comes from the base model
# named in ModelConfig; the fine-tuned weights come from the local checkpoint
# (local_files_only=True: no download is attempted).
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = (
    VisionEncoderDecoderModel
    .from_pretrained(model_full_path, local_files_only=True)
    .to(device)
)

In [None]:
def ocr(image, processor, model, **generate_kwargs):
    """Run the OCR model on one image and return the decoded text.

    Args:
        image: Input image (e.g. an RGB ``PIL.Image``), passed unchanged to ``processor``.
        processor: TrOCR processor; builds ``pixel_values`` and decodes generated ids.
        model: Encoder-decoder model exposing ``generate`` and a ``device`` attribute.
        **generate_kwargs: Extra options forwarded to ``model.generate``
            (e.g. ``max_new_tokens``, ``num_beams``).

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    # Put the inputs on the model's own device instead of the notebook-global
    # ``device``, so the function works no matter where ``model`` lives.
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(model.device)
    generated_ids = model.generate(pixel_values, **generate_kwargs)
    # batch_decode yields one string per generated sequence; keep the first.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [None]:
def eval_new_data(num_samples=50, image_glob="/home/foziljon/npr/licence_plate/lp_test/*"):
    """OCR a set of test images and plot each one titled with its predicted text.

    Uses the notebook-level ``processor`` and ``trained_model``.

    Args:
        num_samples: Maximum number of images to process. ``None`` or a negative
            value processes every matched image.
        image_glob: Glob pattern selecting the images to evaluate.
    """
    # glob order is filesystem-dependent; sort so the chosen subset is reproducible.
    image_paths = sorted(glob.glob(image_glob))
    # Slice up front so the progress bar's total matches the work actually done.
    # A negative limit never matched the old `i == num_samples` check, so it means "all".
    if num_samples is not None and num_samples >= 0:
        image_paths = image_paths[:num_samples]

    for image_path in tqdm(image_paths):
        # Close the file handle; convert() returns an independent, fully loaded copy.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        text = ocr(image, processor, trained_model)

        fig, ax = plt.subplots(figsize=(7, 4))
        ax.imshow(image)
        ax.set_title(text)
        ax.axis('off')
        plt.show()
        # Release each figure explicitly so long loops don't accumulate open figures.
        plt.close(fig)

In [None]:
# Number of test images to run OCR on and plot.
NUM_EVAL_SAMPLES = 100
eval_new_data(num_samples=NUM_EVAL_SAMPLES)