In [1]:
from transformers import MarianMTModel, MarianTokenizer, VisionEncoderDecoderModel, TrOCRProcessor, AutoTokenizer
import torch
from transformers import AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.2"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

In [78]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM
import torch

roberta_model = AutoModelForCausalLM.from_pretrained("FacebookAI/roberta-base", is_decoder=True).to(device)

In [3]:
from PIL import Image

# create ids of encoded input vectors
filepath = 'data/ss2.png'
image = Image.open(filepath).convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

In [38]:
# create BOS token


tensor([[ 100,   64,  422,   24,  423,    4,  978, 1437]], device='mps:0')

In [88]:
#roberta loop

decoder_input_ids = tokenizer("Capital of France is ", add_special_tokens=False, return_tensors="pt").input_ids.to(device)
decoder_input_ids

lm_logits = roberta_model(input_ids=decoder_input_ids, return_dict=True).logits

# sample last token with highest prob again
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1)

# concat again
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], axis=-1)
tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True), decoder_input_ids

('Capital of France is ',
 tensor([[38632,     9,  1470,    16,  1437,     2]], device='mps:0'))

In [21]:
assert decoder_input_ids[0, 0].item() == trocr_model.config.decoder_start_token_id, "`decoder_input_ids` should correspond to `model.config.decoder_start_token_id`"

AssertionError: `decoder_input_ids` should correspond to `model.config.decoder_start_token_id`

In [30]:
# STEP 1

# pass input_ids to encoder and to decoder and pass BOS token to decoder to retrieve first logit
outputs = trocr_model(pixel_values, decoder_input_ids=decoder_input_ids, return_dict=True)

In [31]:
# get encoded sequence
encoded_sequence = (outputs.encoder_last_hidden_state,)
# get logits
lm_logits = outputs.logits

# sample last token with highest prob
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1)

next_decoder_input_ids

tensor([[2]], device='mps:0')

In [32]:
# concat
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], axis=-1)
decoder_input_ids

tensor([[  100,    33,    57,  3044,    70,   363,     4,   978,    47,   646,
         31957,   742,     2]], device='mps:0')

In [33]:
tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

'I have been testing all night. Now you [bos]'

In [37]:
# STEP 2

# reuse encoded_inputs and pass BOS + "Ich" to decoder to second logit
lm_logits = trocr_model(None, encoder_outputs=encoded_sequence, decoder_input_ids=decoder_input_ids, return_dict=True).logits

# sample last token with highest prob again
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1)

# concat again
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], axis=-1)
tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True), decoder_input_ids

('I have been testing all night. Now you [bos] get )',
 tensor([[  100,    33,    57,  3044,    70,   363,     4,   978,    47,   646,
          31957,   742,     2,   120,  4839,     2,     2]], device='mps:0'))

In [14]:

# STEP 3
lm_logits = trocr_model(None, encoder_outputs=encoded_sequence, decoder_input_ids=decoder_input_ids, return_dict=True).logits
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1)
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], axis=-1)

# let's see what we have generated so far!
print(f"Generated so far: {tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)}")

# This can be written in a loop as well.


Generated so far: [bos] 5-10


In [115]:
trocr_model.config.num_beams = 2

In [132]:
def get_model_output(images):
    pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
    output = trocr_model.generate(pixel_values, return_dict_in_generate=True, output_scores=True, max_new_tokens=30, output_logits=True)
    
    generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
    return generated_texts, output.sequences_scores, output

img = Image.open("data/fml_line.png").convert("RGB")
_, _, output = get_model_output([img])

In [133]:
for logit in output.logits:    
    confidence = logit.softmax(-1).max()
    word = tokenizer.decode(logit.argmax(-1))
    print(word, confidence)

IfIf tensor(0.9891, device='mps:0')
 cluster cluster tensor(0.9995, device='mps:0')
 controls controls tensor(0.4203, device='mps:0')
 arer tensor(1.0000, device='mps:0')
 knownids tensor(1.0000, device='mps:0')
 , are tensor(1.0000, device='mps:0')
 it known tensor(1.0000, device='mps:0')
 is , tensor(1.0000, device='mps:0')
 easy it tensor(1.0000, device='mps:0')
 to is tensor(1.0000, device='mps:0')
 do easy tensor(1.0000, device='mps:0')
" to tensor(0.9972, device='mps:0')
 do cluster tensor(0.9992, device='mps:0')
" argument tensor(0.7471, device='mps:0')
 cluster charter tensor(0.4367, device='mps:0')
 argument argument tensor(0.7926, device='mps:0')
</s></s> tensor(0.9955, device='mps:0')
</s></s> tensor(0.9997, device='mps:0')


In [None]:
FML_text = """K Means clustering algorithm
Assume we have K cluster of points; each point in a cluster
Is closest to its centroid (more than any other cluster centroid)
If cluster assignment is known, it is easy to compute the centroid"""

