In [2]:
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import numpy as np
import pandas as pd
from nltk.translate import bleu_score
from nltk.translate.bleu_score import SmoothingFunction
import torch

# Checkpoint of the fine-tuned YOLO handwritten-line detector.
yolo_weights_path = "runs/detect/train103/weights/last.pt"

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
# Beam search: inference() reads `output.sequences_scores`, which generate()
# populates for beam-search decoding.
trocr_model.config.num_beams = 2

# Use the same detected device as TrOCR; hard-coding 'mps' fails on
# CUDA-only or CPU-only machines.
yolo_model = YOLO(yolo_weights_path).to(device)

print(f'TrOCR and YOLO Models loaded on {device}')

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.2"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

TrOCR and YOLO Models loaded on mps


In [15]:
# A recognised line is kept only if exp(beam score) is strictly above this.
CONFIDENCE_THRESHOLD = 0.72

# Adjacent lines whose bigram BLEU reaches this value are treated as
# duplicate detections of one text line.
BLEU_THRESHOLD = 0.6


def inference(image_path, debug=False, return_texts='final'):
    """Detect handwritten lines in a page image and transcribe them.

    Pipeline stages (any of them can be returned via ``return_texts``):
      'generated'           raw TrOCR output for every YOLO line crop
      'post_processed'      same, with a leading '# ' prefix stripped
      'qualified'           only lines whose score > CONFIDENCE_THRESHOLD
      'qualified_with_bleu' qualified lines plus a 'bleu' column: bigram BLEU
                            of each line against the next one (0 for the last)
      'final'               runs of adjacent lines with bleu >= BLEU_THRESHOLD
                            collapsed to their highest-scoring member

    Side effect: YOLO is called with ``save=True``, so Ultralytics writes an
    annotated copy of the prediction to disk.

    Args:
        image_path: Path of the page image; passed to YOLO and opened with PIL.
        debug: If True, print how many line crops were detected.
        return_texts: Name of the stage to return (see above).

    Returns:
        pandas.DataFrame with columns 'text', 'score' (exp of the TrOCR beam
        score) and 'y' (top edge of the crop, in pixels), plus 'bleu' for the
        'qualified_with_bleu' stage. Rows follow the boxes' vertical centres,
        top to bottom. If YOLO detects no lines, an empty frame with the same
        columns is returned.

    Raises:
        ValueError: If ``return_texts`` is not one of the stage names.
    """
    valid_stages = ('generated', 'post_processed', 'qualified',
                    'qualified_with_bleu', 'final')
    if return_texts not in valid_stages:
        # Fail loudly instead of silently falling through to 'final'.
        raise ValueError(
            f"return_texts must be one of {valid_stages}, got {return_texts!r}")

    base_columns = ['text', 'score', 'y']
    bleu_columns = base_columns + ['bleu']

    def get_cropped_images(image_path):
        # Run YOLO line detection and crop every detected box out of the page.
        results = yolo_model(image_path, save=True)
        # Open the page once (not once per box) and release the file handle
        # promptly; convert() returns an in-memory copy that outlives the `with`.
        with Image.open(image_path) as page:
            image = page.convert("RGB")
        patches = []
        ys = []
        # Sort by the vertical box centre so lines come out top to bottom.
        for box in sorted(results[0].boxes, key=lambda b: float(b.xywh[0][1])):
            x_center, y_center, w, h = box.xywh[0].cpu().numpy()
            x, y = x_center - w / 2, y_center - h / 2
            patches.append(image.crop((x, y, x + w, y + h)))
            ys.append(y)
        return patches, ys

    def get_model_output(images):
        # Batch-recognise all crops with TrOCR. `sequences_scores` is part of
        # the beam-search output, so this relies on num_beams > 1.
        pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
        output = trocr_model.generate(pixel_values, return_dict_in_generate=True,
                                      output_scores=True, max_new_tokens=30)
        generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
        return generated_texts, output.sequences_scores

    def post_process_texts(generated_texts):
        # Strip a leading '# ' (presumably a recognition artefact) unless it
        # is the entire text.
        return [text[2:] if len(text) > 2 and text.startswith('# ') else text
                for text in generated_texts]

    def get_qualified_texts(generated_texts, scores, y):
        # Keep only lines whose score is strictly above CONFIDENCE_THRESHOLD.
        return [
            {'text': text, 'score': score, 'y': y_i}
            for text, score, y_i in zip(generated_texts, scores, y)
            if score > CONFIDENCE_THRESHOLD
        ]

    def get_adjacent_bleu_scores(qualified_texts):
        # Annotate each line with the BLEU of its tokens against the next
        # line's tokens (unigram and bigram precision weighted equally); the
        # last line gets 0. High values flag overlapping detections.
        smoothing = SmoothingFunction().method1
        for i, entry in enumerate(qualified_texts):
            bleu = 0
            if i + 1 < len(qualified_texts):
                hyp = entry['text'].split()
                ref = qualified_texts[i + 1]['text'].split()
                bleu = bleu_score.sentence_bleu([ref], hyp, weights=[0.5, 0.5],
                                                smoothing_function=smoothing)
            entry['bleu'] = bleu
        return qualified_texts

    def remove_overlapping_texts(qualified_texts):
        # Walk lines top to bottom. A line whose predecessor had
        # bleu >= BLEU_THRESHOLD joins that predecessor's group, and each
        # group keeps only its highest-scoring line.
        final_texts = []
        new = True
        for entry in qualified_texts:
            if new:
                final_texts.append(entry)
            elif final_texts[-1]['score'] < entry['score']:
                final_texts[-1] = entry
            # Whether the *next* line starts a new group depends on this line's bleu.
            new = entry['bleu'] < BLEU_THRESHOLD
        return final_texts

    cropped_images, y = get_cropped_images(image_path)
    if debug:
        print('Number of cropped images:', len(cropped_images))
    if not cropped_images:
        # Nothing detected: TrOCR cannot run on an empty batch.
        columns = bleu_columns if return_texts == 'qualified_with_bleu' else base_columns
        return pd.DataFrame(columns=columns)

    generated_texts, scores = get_model_output(cropped_images)
    # Beam scores are in log space; exp() maps them to a (0, 1] confidence.
    normalised_scores = np.exp(scores.to('cpu').numpy())
    if return_texts == 'generated':
        return pd.DataFrame({'text': generated_texts, 'score': normalised_scores, 'y': y})

    generated_texts = post_process_texts(generated_texts)
    if return_texts == 'post_processed':
        return pd.DataFrame({'text': generated_texts, 'score': normalised_scores, 'y': y})

    qualified_texts = get_qualified_texts(generated_texts, normalised_scores, y)
    if return_texts == 'qualified':
        return pd.DataFrame(qualified_texts, columns=base_columns)

    qualified_texts = get_adjacent_bleu_scores(qualified_texts)
    if return_texts == 'qualified_with_bleu':
        return pd.DataFrame(qualified_texts, columns=bleu_columns)

    final_texts = remove_overlapping_texts(qualified_texts)
    return pd.DataFrame(final_texts, columns=base_columns)


# Run the full pipeline on a sample page; the final DataFrame is the cell output.
image_path = "data/akansha_hw.png"
inference(image_path=image_path, debug=False, return_texts='final')


image 1/1 /Users/amaljoe/Desktop/Workspace/IITB/NLP/OCR_with_LLMs/data/akansha_hw.png: 224x640 8 handwritten_lines, 74.1ms
Speed: 7.1ms preprocess, 74.1ms inference, 7.5ms postprocess per image at shape (1, 3, 224, 640)
Results saved to [1m/Users/amaljoe/Desktop/Workspace/IITB/NLP/OCR_with_LLMs/runs/detect/predict11[0m


Unnamed: 0,text,score,y
0,Machine learning is a branch of artificial Int...,0.836044,109.091827
1,on building systems capable of learning from l...,0.918816,335.959808
2,Popular algorithms includes support Vector mac...,0.9373,494.6772
