In [1]:
import random

In [1]:
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForCausalLM
from PIL import Image
import numpy as np
import pandas as pd
from nltk.translate import bleu_score
from nltk.translate.bleu_score import SmoothingFunction
import torch

yolo_weights_path = "final_wts.pt"

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
trocr_model.config.num_beams = 2

yolo_model = YOLO(yolo_weights_path).to('mps')

print(f'TrOCR and YOLO Models loaded on {device}')

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.2"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

TrOCR and YOLO Models loaded on mps


In [15]:
CONFIDENCE_THRESHOLD = 0.72
BLEU_THRESHOLD = 0.6


def inference(image_path, debug=False, return_texts='final'):
    def get_cropped_images(image_path):
        results = yolo_model(image_path, save=True)
        patches = []
        ys = []
        for box in sorted(results[0].boxes, key=lambda x: x.xywh[0][1]):
            image = Image.open(image_path).convert("RGB")
            x_center, y_center, w, h  = box.xywh[0].cpu().numpy()
            x, y = x_center - w / 2, y_center - h / 2
            cropped_image = image.crop((x, y, x + w, y + h))
            patches.append(cropped_image)
            ys.append(y)
        bounding_box_path = results[0].save_dir + results[0].path[results[0].path.rindex('/'):-4] + '.jpg'
        return patches, ys, bounding_box_path

    def get_model_output(images):
        pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
        output = trocr_model.generate(pixel_values, return_dict_in_generate=True, output_scores=True, max_new_tokens=30, output_logits=True)
        generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
        return generated_texts, output.sequences_scores, output.logits

    def post_process_texts(generated_texts):
        for i in range(len(generated_texts)):
            if len(generated_texts[i]) > 2 and generated_texts[i][:2] == '# ':
                generated_texts[i] = generated_texts[i][2:]

            if len(generated_texts[i]) > 2 and generated_texts[i][-2:] == ' #':
                generated_texts[i] = generated_texts[i][:-2]
        return generated_texts

    def get_qualified_texts(generated_texts, scores, y, logits):
        qualified_texts = []
        for text, score, y_i,  logits_i in zip(generated_texts, scores, y, logits):
            if score > CONFIDENCE_THRESHOLD:
                qualified_texts.append({
                    'text': text,
                    'score': score,
                    'y': y_i,
                    'logits': logits_i
                })
        return qualified_texts

    def get_adjacent_bleu_scores(qualified_texts):
        def get_bleu_score(hypothesis, references):
            weights = [0.5, 0.5]
            smoothing = SmoothingFunction()
            return bleu_score.sentence_bleu(references, hypothesis, weights=weights,
                                            smoothing_function=smoothing.method1)

        for i in range(len(qualified_texts)):
            hyp = qualified_texts[i]['text'].split()
            bleu = 0
            if i < len(qualified_texts) - 1:
                ref = qualified_texts[i + 1]['text'].split()
                bleu = get_bleu_score(hyp, [ref])
            qualified_texts[i]['bleu'] = bleu
        return qualified_texts

    def remove_overlapping_texts(qualified_texts):
        final_texts = []
        new = True
        for i in range(len(qualified_texts)):
            if new:
                final_texts.append(qualified_texts[i])
            else:
                if final_texts[-1]['score'] < qualified_texts[i]['score']:
                    final_texts[-1] = qualified_texts[i]
            new = qualified_texts[i]['bleu'] < BLEU_THRESHOLD
        return final_texts

    cropped_images, y, bounding_box_path = get_cropped_images(image_path)
    if debug:
        print('Number of cropped images:', len(cropped_images))
    generated_texts, scores, logits = get_model_output(cropped_images[:10])
    normalised_scores = np.exp(scores.to('cpu').numpy())
    index = list(range(len(cropped_images)))
    print('Number of generated texts:', len(generated_texts))
    return ",", "", logits
    if return_texts == 'generated':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
            'logits': logits,
        })
    generated_texts = post_process_texts(generated_texts)
    if return_texts == 'post_processed':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
            'logits': logits,
        })
    qualified_texts = get_qualified_texts(generated_texts, normalised_scores, y, logits)
    if return_texts == 'qualified':
        return pd.DataFrame(qualified_texts)
    qualified_texts = get_adjacent_bleu_scores(qualified_texts)
    if return_texts == 'qualified_with_bleu':
        return pd.DataFrame(qualified_texts)
    final_texts = remove_overlapping_texts(qualified_texts)
    final_texts_df = pd.DataFrame(final_texts, columns=['text', 'score', 'y', 'logits'])
    return final_texts_df, bounding_box_path, logits


image_path = "raw_dataset/p03-112.png"
df, bounding_path, ocr_logits = inference(image_path, debug=False, return_texts='generated')
df


image 1/1 /Users/amaljoe/Desktop/Workspace/IITB/NLP/OCR_with_LLMs/raw_dataset/p03-112.png: 576x640 14 handwritten_lines, 56.0ms
Speed: 26.6ms preprocess, 56.0ms inference, 8.8ms postprocess per image at shape (1, 3, 576, 640)
Results saved to [1m/Users/amaljoe/Desktop/Workspace/IITB/NLP/OCR_with_LLMs/runs/detect/predict31[0m
Number of generated texts: 10


','

tensor([  113,   113, 12557, 12557,  9773,  9773, 13639, 13639, 13639, 13639,   113,   113, 23033, 23033,  1610,  1610,  7325,  7325,  1646,  1646], device='mps:0')
tensor([  440,   440, 21434,   710,  1493,  1493,  4399,    47,  4399,    47, 16734,   113,   101, 11990, 11990,   101,   129,    21,  5606, 20161], device='mps:0')
tensor([ 2156,  2156,    69, 21434,    13,    13,    47,     7,    47,     7,   111, 16734,   932,   101,   101,   932, 22024,   129,   615,   475], device='mps:0')
tensor([   24,    24,   479,    69,    10,    10,     7, 18871,     7, 18871, 27785,    22,     9,   932,   932,     9,   479, 22024,   127,  5999], device='mps:0')
tensor([ 1979,  1979,    22,   479,  1151,  1151, 18871,   101, 18871,   101,    22,    79,     5,     9,     9,     5,    22,   479,   519,   868], device='mps:0')
tensor([   75,    75,   370,    22,    89,    89,   101,    10,   101,    10,    79, 44918,   761,     5,     5,   761,    22,    22,     7,   615], device='mps:0')
tensor([27

In [24]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
model = AutoModelForCausalLM.from_pretrained("FacebookAI/roberta-base")

If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`


In [98]:
ocr_logits2 = model(input_ids=tokenizer(["hello", "there you are"], return_tensors="pt", padding=True, truncation=True).input_ids).logits

logits2 = torch.vstack([logit for logit in ocr_logits])

ids = logits2.argmax(axis=-1)
logits2.shape

torch.Size([320, 50265])

In [99]:
dec = tokenizer.decode(ids, skip_special_tokens=True)
print(f'{dec}\n\n')

""sursurbodybodybecbecbecbec""beforebeforebebewaswas1919 No Noprisingur else elseome youome you Nigel" likehavehave like only was61 miserable , , herprising for for you to you to - Nigel anything like like anything joking only enough m it it . her a a to behave to behave ! " of anything anything of . joking myiser wouldn wouldn " . moment moment behave like behave like " she the of of the " . havingable't't You " there there like a like a she gasped kind the the kind " " to enough ! ! sounded You . . a - a - gasped . . kind kind . Well , take my " " like sounded It It - a - a . " I . . I , don my having he he some like doesn doesn a tart a a " I I I don't to almost almost- some't't tart pl tart tart I didn't . take snapped snapped pl . Di Di didn't . It my my , , .ili . .'t mean . . It's " " " " mean to . . . .'s to my my




In [44]:
ocr_ids = ocr_logits.argmax(axis=-1)
output = tokenizer.batch_decode(ocr_ids, skip_special_tokens=True)

AttributeError: 'tuple' object has no attribute 'argmax'

In [1]:
inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")

with torch.no_grad():
    logits = model(input_ids=inputs.input_ids).logits

# retrieve index of [MASK]
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
bert_logits = logits[0, mask_token_index]
print(logits[0][-1].shape)

NameError: name 'tokenizer' is not defined

In [38]:
predicted_token_id = logits[0, :].argmax(axis=-1)
tokenizer.decode(predicted_token_id)

'<s>The capital of France is Paris.</s>'

In [None]:
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
# mask labels of non-[MASK] tokens
labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)

outputs = model(**inputs, labels=labels)
round(outputs.loss.item(), 2)