In [1]:
from transformers import MarianMTModel, MarianTokenizer, VisionEncoderDecoderModel, TrOCRProcessor, AutoTokenizer, \
    AutoModelForCausalLM
import torch
from transformers import AutoTokenizer


device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
roberta_model = AutoModelForCausalLM.from_pretrained("FacebookAI/roberta-base", is_decoder=True).to(device)

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.2"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

In [21]:
from PIL import Image

# create ids of encoded input vectors
filepath = 'data/ss2.png'
image = Image.open(filepath).convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
prev_text = "I can run it later, Now"
roberta_input_ids = tokenizer(prev_text, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
decoder_input_ids = torch.tensor([[processor.tokenizer.cls_token_id]]).to(device)
outputs = trocr_model(pixel_values, decoder_input_ids=decoder_input_ids, return_dict=True)
encoded_sequence = (outputs.encoder_last_hidden_state,)

roberta_input_ids, decoder_input_ids

(tensor([[100,  64, 422,  24, 423,   6, 978]], device='mps:0'),
 tensor([[0]], device='mps:0'))

In [32]:
roberta_logits = roberta_model(input_ids=roberta_input_ids, return_dict=True).logits
trocr_logits = trocr_model(None, encoder_outputs=encoded_sequence, decoder_input_ids=decoder_input_ids, return_dict=True).logits

trocr_confidence = torch.softmax(trocr_logits[:, -1:], dim=-1).flatten().max()

lm_next_logit = (1 - trocr_confidence) * roberta_logits[:, -1:] + trocr_confidence * trocr_logits[:, -1:]

next_decoder_input_ids = torch.argmax(lm_next_logit, axis=-1)
roberta_input_ids = torch.cat([roberta_input_ids, next_decoder_input_ids], axis=-1)
decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], axis=-1)

tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True), decoder_input_ids, trocr_confidence

('Test 5-10 evaluation chats ( once I get )',
 tensor([[    0, 34603,   195,    12,   698, 10437, 28975,    36,   683,    38,
            120,  4839]], device='mps:0'),
 tensor(0.9991, device='mps:0', grad_fn=<MaxBackward1>))

In [51]:
trocr_model(None, encoder_outputs=encoded_sequence, decoder_input_ids=decoder_input_ids, output_scores=True,  return_dict=True).keys()

odict_keys(['logits', 'encoder_last_hidden_state'])