In [224]:
from pycparser.ply.yacc import token
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForCausalLM, pipeline, AutoModelForMaskedLM
from PIL import Image
import numpy as np
import pandas as pd
from nltk.translate import bleu_score
from nltk.translate.bleu_score import SmoothingFunction
import torch

yolo_weights_path = "final_wts.pt"

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
trocr_model.config.num_beams = 1

yolo_model = YOLO(yolo_weights_path).to('mps')
unmasker_large = pipeline('fill-mask', model='roberta-large', device=device)
roberta_model = AutoModelForMaskedLM.from_pretrained("roberta-large").to(device)

print(f'TrOCR and YOLO Models loaded on {device}')

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.2"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

TrOCR and YOLO Models loaded on mps


In [268]:
CONFIDENCE_THRESHOLD = 0.72
BLEU_THRESHOLD = 0.6


def inference(image_path, debug=False, return_texts='final'):
    def get_cropped_images(image_path):
        results = yolo_model(image_path, save=True)
        patches = []
        ys = []
        for box in sorted(results[0].boxes, key=lambda x: x.xywh[0][1]):
            image = Image.open(image_path).convert("RGB")
            x_center, y_center, w, h  = box.xywh[0].cpu().numpy()
            x, y = x_center - w / 2, y_center - h / 2
            cropped_image = image.crop((x, y, x + w, y + h))
            patches.append(cropped_image)
            ys.append(y)
        bounding_box_path = results[0].save_dir + results[0].path[results[0].path.rindex('/'):-4] + '.jpg'
        return patches, ys, bounding_box_path

    def get_model_output(images):
        pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
        output = trocr_model.generate(pixel_values, return_dict_in_generate=True, output_logits=True, max_new_tokens=30)
        generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
        generated_tokens = [processor.tokenizer.convert_ids_to_tokens(seq) for seq in output.sequences]
        stacked_logits = torch.stack(output.logits, dim=1)
        return generated_texts, stacked_logits, generated_tokens

    def get_scores(logits):
        scores = logits.softmax(-1).max(-1).values.mean(-1)
        return scores

    def post_process_texts(generated_texts):
        for i in range(len(generated_texts)):
            if len(generated_texts[i]) > 2 and generated_texts[i][:2] == '# ':
                generated_texts[i] = generated_texts[i][2:]
                
            if len(generated_texts[i]) > 2 and generated_texts[i][-2:] == ' #':
                generated_texts[i] = generated_texts[i][:-2]
        return generated_texts

    def get_qualified_texts(generated_texts, scores, y, logits, tokens):
        qualified_texts = []
        for text, score, y_i, logits_i, tokens_i in zip(generated_texts, scores, y, logits, tokens):
            if score > CONFIDENCE_THRESHOLD:
                qualified_texts.append({
                    'text': text,
                    'score': score,
                    'y': y_i,
                    'logits': logits_i,
                    'tokens': tokens_i
                })
        return qualified_texts

    def get_adjacent_bleu_scores(qualified_texts):
        def get_bleu_score(hypothesis, references):
            weights = [0.5, 0.5]
            smoothing = SmoothingFunction()
            return bleu_score.sentence_bleu(references, hypothesis, weights=weights,
                                            smoothing_function=smoothing.method1)

        for i in range(len(qualified_texts)):
            hyp = qualified_texts[i]['text'].split()
            bleu = 0
            if i < len(qualified_texts) - 1:
                ref = qualified_texts[i + 1]['text'].split()
                bleu = get_bleu_score(hyp, [ref])
            qualified_texts[i]['bleu'] = bleu
        return qualified_texts

    def remove_overlapping_texts(qualified_texts):
        final_texts = []
        new = True
        for i in range(len(qualified_texts)):
            if new:
                final_texts.append(qualified_texts[i])
            else:
                if final_texts[-1]['score'] < qualified_texts[i]['score']:
                    final_texts[-1] = qualified_texts[i]
            new = qualified_texts[i]['bleu'] < BLEU_THRESHOLD
        return final_texts

    cropped_images, y, bounding_box_path = get_cropped_images(image_path)
    if debug:
        print('Number of cropped images:', len(cropped_images))
    generated_texts, logits, gen_tokens = get_model_output(cropped_images)
    normalised_scores = get_scores(logits)
    if return_texts == 'generated':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
        })
    generated_texts = post_process_texts(generated_texts)
    if return_texts == 'post_processed':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y
        })
    qualified_texts = get_qualified_texts(generated_texts, normalised_scores, y, logits, gen_tokens)
    if return_texts == 'qualified':
        return pd.DataFrame(qualified_texts)
    qualified_texts = get_adjacent_bleu_scores(qualified_texts)
    if return_texts == 'qualified_with_bleu':
        return pd.DataFrame(qualified_texts)
    final_texts = remove_overlapping_texts(qualified_texts)
    final_texts_df = pd.DataFrame(final_texts, columns=['text', 'score', 'y'])
    final_tokens = [text['tokens'] for text in final_texts]
    final_logits = [text['logits'] for text in final_texts]
    if return_texts == 'final':
        return final_texts_df
    
    return final_texts_df, bounding_box_path, final_tokens, final_logits


image_path = "raw_dataset/g06-037h.png"
df, bounding_path, tokens, logits = inference(image_path, debug=False, return_texts='final_v2')
df


image 1/1 /Users/amaljoe/Desktop/Workspace/IITB/NLP/OCR_with_LLMs/raw_dataset/g06-037h.png: 576x640 14 handwritten_lines, 487.1ms
Speed: 36.3ms preprocess, 487.1ms inference, 10.2ms postprocess per image at shape (1, 3, 576, 640)
Results saved to [1m/Users/amaljoe/Desktop/Workspace/IITB/NLP/OCR_with_LLMs/runs/detect/predict36[0m


Unnamed: 0,text,score,y
0,"God grant , however , that I may be a false prophet","tensor(0.9806, device='mps:0')",82.707489
1,& that all may go well . Sir R. Peel was,"tensor(0.9002, device='mps:0')",267.147385
2,"here , I understand , but an express task him off","tensor(0.8848, device='mps:0')",436.745956
3,Yesterday . ',"tensor(0.8136, device='mps:0')",610.187439
4,While he was in Naples there had opened a new,"tensor(0.9930, device='mps:0')",792.05722
5,chapter in the history of Anglesey's unceasing,"tensor(0.9791, device='mps:0')",994.093018
6,search for an effective alleviation of his painful,"tensor(0.9445, device='mps:0')",1147.620361
7,absolutely . None of the numerous conventional,"tensor(0.9200, device='mps:0')",1330.609741
8,to remedies to which he had been subjected ever,"tensor(0.9066, device='mps:0')",1513.878235
9,Since the symptoms had first shown them -,"tensor(0.9270, device='mps:0')",1678.105103


In [267]:
def get_new_logits(tokens):
    inputs = tokens.reshape(1, -1)
    # Get the logits from the model
    with torch.no_grad():
        outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones(inputs.shape).to(device))
        logits = outputs.logits


    logits_flattened = logits.reshape(-1, slogits.shape[-1])
    print(processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True))
    return logits.reshape(tokens.shape + (logits.shape[-1],))


slogits = torch.stack([logit for logit in logits], dim=0)
tokens = slogits.argmax(-1)
confidence = slogits.softmax(-1).max(-1).values
indices = torch.where(confidence < 0.5)
# put 50264(mask) when confidence < 0.5
for i, j in zip(indices[0], indices[1]):
    if i != 6:
        continue
    tokens[i, j] = torch.tensor(50264)

new_logits = get_new_logits(tokens)

["# , however , that I may be a false prophet# & that all may go well . Sir R. Peel washere , I understand , but an express carried him offhere . 'While he was in Naples there had opened a newchapter in. history of Anglesey's unceasingsearch for an effective alleviation of his pain.effects . None of the numerousto remedies to which he had been subjectedSince the symptoms had first shown themselves -# - seventeen years before had had the slightest effect ."]


In [250]:
for i, j in zip(indices[0], indices[1]):
    slogits[i, j] = slogits[i, j] * 0.5 + new_logits[i, j] * 0.5

logits_flattened = slogits.reshape(-1, slogits.shape[-1])
processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True)

["God grant , however , that I may be a false prophet# and that all may go well . Sir R. Peel was# , I understand , but an express task him offand . 'While he was in Naples there had opened a newchapter in the history of Anglesey's unceasingsearch for an effective alleviation of his painful# . None of the numerous conventionalmedical remedies to which he had been subjected eversince the symptoms had first shown them -#. seventeen years before had had the slightest effect ."]

In [208]:
processor.tokenizer.encode("")
processor.tokenizer.decode([0, 0])

'<s><s>'

In [172]:
def get_cropped_images(image_path):
    results = yolo_model(image_path, save=True)
    patches = []
    ys = []
    for box in sorted(results[0].boxes, key=lambda x: x.xywh[0][1]):
        image = Image.open(image_path).convert("RGB")
        x_center, y_center, w, h  = box.xywh[0].cpu().numpy()
        x, y = x_center - w / 2, y_center - h / 2
        cropped_image = image.crop((x, y, x + w, y + h))
        patches.append(cropped_image)
        ys.append(y)
    bounding_box_path = results[0].save_dir + results[0].path[results[0].path.rindex('/'):-4] + '.jpg'
    return patches, ys, bounding_box_path

def get_model_output(images):
    pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
    output = trocr_model.generate(pixel_values, return_dict_in_generate=True, output_logits=True, max_new_tokens=30)
    generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
    generated_tokens = [processor.tokenizer.convert_ids_to_tokens(seq) for seq in output.sequences]
    logits = torch.stack(output.logits, dim=1)
    return generated_texts, logits, generated_tokens


image_path = "data/FML_whiteboard2.png"
cropped_images, y, bounding_box_path = get_cropped_images(image_path)
generated_texts, logits, gen_tokens = get_model_output(cropped_images)
for i in range(len(generated_texts)):
    print(generated_texts[i], logits[i].softmax(-1).max(-1).values.mean())


image 1/1 /Users/amaljoe/Desktop/Workspace/IITB/NLP/OCR_with_LLMs/data/FML_whiteboard2.png: 384x640 8 handwritten_lines, 109.5ms
Speed: 6.8ms preprocess, 109.5ms inference, 10.9ms postprocess per image at shape (1, 3, 384, 640)
Results saved to [1m/Users/amaljoe/Desktop/Workspace/IITB/NLP/OCR_with_LLMs/runs/detect/predict34[0m
K-means clustering algorithm tensor(0.9335, device='mps:0')
Assume we have K clusters of points ; each point in a cluster . tensor(0.9584, device='mps:0')
Assume we have K clusters of points ; each point in a cluster . tensor(0.9604, device='mps:0')
is closest to its centroid ( more than any other cluster centroid ) tensor(0.9554, device='mps:0')
If cluster assignment is known , it is easy to compute the centriots . tensor(0.9374, device='mps:0')
It cluster assignment is known , it is easy to compute the centroids . tensor(0.9117, device='mps:0')
If cluster centrids are known , it is easy to do cluster assignment . tensor(0.8591, device='mps:0')
How do we solv

In [122]:
def get_scores(logits):
    stacked_logits = torch.stack(logits, dim=1)
    scores = stacked_logits.softmax(-1).max(-1).values.mean(-1)
    return scores

get_scores(logits).shape

torch.Size([8])

In [155]:
generated_texts

[['',
  'K',
  '-',
  'me',
  'ans',
  ' clust',
  'ering',
  ' algorithm',
  '',
  '',
  '',
  '',
  '',
  '',
  '',
  '',
  '',
  '',
  '',
  ''],
 ['',
  'Ass',
  'ume',
  ' we',
  ' have',
  ' K',
  ' clusters',
  ' of',
  ' points',
  ' ;',
  ' each',
  ' point',
  ' in',
  ' a',
  ' cluster',
  ' .',
  '',
  '',
  '',
  ''],
 ['',
  'Ass',
  'ume',
  ' we',
  ' have',
  ' K',
  ' clusters',
  ' of',
  ' points',
  ' ;',
  ' each',
  ' point',
  ' in',
  ' a',
  ' cluster',
  ' .',
  '',
  '',
  '',
  ''],
 ['',
  'is',
  ' closest',
  ' to',
  ' its',
  ' cent',
  'roid',
  ' (',
  ' more',
  ' than',
  ' any',
  ' other',
  ' cluster',
  ' cent',
  'roid',
  ' )',
  '',
  '',
  '',
  ''],
 ['',
  'If',
  ' cluster',
  ' assignment',
  ' is',
  ' known',
  ' ,',
  ' it',
  ' is',
  ' easy',
  ' to',
  ' compute',
  ' the',
  ' cent',
  'riots',
  ' .',
  '',
  '',
  '',
  ''],
 ['',
  'It',
  ' cluster',
  ' assignment',
  ' is',
  ' known',
  ' ,',
  ' it',
  ' is',
  ' easy',
 

In [146]:
text = " ".join(generated_texts)
text

'K-means clustering algorithm Assume we have K clusters of points ; each point in a cluster . Assume we have K clusters of points ; each point in a cluster . is closest to its centroid ( more than any other cluster centroid ) If cluster assignment is known , it is easy to compute the centriots . It cluster assignment is known , it is easy to compute the centroids . If cluster centrids are known , it is easy to do cluster assignment . How do we solve this chicken-egg problem ? Fix one , optimize the other !'

In [162]:
confidence = logits.softmax(-1).max(-1).values
mask_indices = torch.where(confidence < 0.5)

for y, x in zip(mask_indices[0], mask_indices[1]):
    for i in range(x, )

 cent
 cent
riots
 .



IndexError: string index out of range

In [89]:
stacked_logits = torch.stack(logits, dim=1)
processor.batch_decode([stacked_logits[-2].argmax(-1)[:5]], skip_special_tokens=True)

['If cluster centrids']

In [71]:
stacked_logits[-2].softmax(-1).max(-1)

torch.return_types.max(
values=tensor([0.5884, 0.9979, 0.3797, 0.8033, 0.7809, 1.0000, 1.0000, 0.9999, 0.9997, 1.0000, 0.9999, 0.9966, 0.9997, 0.2181, 0.9184, 0.6810, 0.9987, 0.9935, 0.9677], device='mps:0'),
indices=tensor([ 1106, 18016,   715,   338,  7823,    32,   684,  2156,    24,    16,  1365,     7,   109, 18016, 11717,   479,     2,     2,     2], device='mps:0'))

In [117]:
res = unmasker_large("""K Means clustering algorithm
Assume we have K cluster of points; each point in a cluster
Is closest to its centroid (more than any other cluster centroid)
If cluster assignment is known, it is easy to compute the centroid
If cluster <mask> is known, it is easy to do cluster assignment
How do we solve this chicken-egg problem? Fix one, optimize the other!""", top_k=10)
res

[{'score': 0.31558406352996826,
  'token': 11717,
  'token_str': ' assignment',
  'sequence': 'K Means clustering algorithm\nAssume we have K cluster of points; each point in a cluster\nIs closest to its centroid (more than any other cluster centroid)\nIf cluster assignment is known, it is easy to compute the centroid\nIf cluster assignment is known, it is easy to do cluster assignment\nHow do we solve this chicken-egg problem? Fix one, optimize the other!'},
 {'score': 0.06008317694067955,
  'token': 15229,
  'token_str': ' composition',
  'sequence': 'K Means clustering algorithm\nAssume we have K cluster of points; each point in a cluster\nIs closest to its centroid (more than any other cluster centroid)\nIf cluster assignment is known, it is easy to compute the centroid\nIf cluster composition is known, it is easy to do cluster assignment\nHow do we solve this chicken-egg problem? Fix one, optimize the other!'},
 {'score': 0.03243965655565262,
  'token': 22432,
  'token_str': ' ali

In [110]:
cluster_pred = stacked_logits[-2].softmax(-1)[2]
res = unmasker_large("""K Means clustering algorithm
Assume we have K cluster of points; each point in a cluster
Is closest to its centroid (more than any other cluster centroid)
If cluster assignment is known, it is easy to compute the centroid
If cluster <mask> is known, it is easy to do cluster assignment""")


for pred in res:
    score, token, str = pred['score'], pred['token'], pred['token_str']
    confidence = score + cluster_pred[token]
    print(str, confidence, cluster_pred[token])

 assignment tensor(0.7182, device='mps:0') tensor(6.8814e-10, device='mps:0')
 alignment tensor(0.0231, device='mps:0') tensor(2.5890e-10, device='mps:0')
 identity tensor(0.0183, device='mps:0') tensor(3.6747e-07, device='mps:0')
 composition tensor(0.0157, device='mps:0') tensor(2.2282e-07, device='mps:0')
 orientation tensor(0.0120, device='mps:0') tensor(1.8597e-07, device='mps:0')
