In [1]:
from transformers import MarianMTModel, MarianTokenizer, VisionEncoderDecoderModel, TrOCRProcessor, AutoTokenizer
import torch
from transformers import AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.2"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

In [5]:
from PIL import Image

# create ids of encoded input vectors
filepath = 'data/ss2.png'
image = Image.open(filepath).convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

In [7]:
# create BOS token
decoder_input_ids = tokenizer("[bos]", add_special_tokens=False, return_tensors="pt").input_ids.to(device)
decoder_input_ids

tensor([[10975, 31957,   742]], device='mps:0')

In [21]:
assert decoder_input_ids[0, 0].item() == trocr_model.config.decoder_start_token_id, "`decoder_input_ids` should correspond to `model.config.decoder_start_token_id`"

AssertionError: `decoder_input_ids` should correspond to `model.config.decoder_start_token_id`

In [8]:
# STEP 1

# pass input_ids to encoder and to decoder and pass BOS token to decoder to retrieve first logit
outputs = trocr_model(pixel_values, decoder_input_ids=decoder_input_ids, return_dict=True)

In [10]:
# get encoded sequence
encoded_sequence = (outputs.encoder_last_hidden_state,)
# get logits
lm_logits = outputs.logits

# sample last token with highest prob
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1)

next_decoder_input_ids

tensor([[195]], device='mps:0')

In [11]:
# concat
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], axis=-1)
decoder_input_ids

tensor([[10975, 31957,   742,   195]], device='mps:0')

In [12]:
tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

'[bos] 5'

In [13]:
# STEP 2

# reuse encoded_inputs and pass BOS + "Ich" to decoder to second logit
lm_logits = trocr_model(None, encoder_outputs=encoded_sequence, decoder_input_ids=decoder_input_ids, return_dict=True).logits

# sample last token with highest prob again
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1)

# concat again
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], axis=-1)
tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

'[bos] 5-'

In [14]:

# STEP 3
lm_logits = trocr_model(None, encoder_outputs=encoded_sequence, decoder_input_ids=decoder_input_ids, return_dict=True).logits
next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1)
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], axis=-1)

# let's see what we have generated so far!
print(f"Generated so far: {tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)}")

# This can be written in a loop as well.


Generated so far: [bos] 5-10
