In [1]:
from transformers import VisionEncoderDecoderModel, ElectraConfig, CamembertTokenizer, ElectraForCausalLM
from torch import nn
import torch

  from .autonotebook import tqdm as notebook_tqdm


#### Set up the model

In [None]:
# Inspect the pretrained checkpoint's architecture (a ViT encoder and an ELECTRA decoder, per the output below).
trocr_reference = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr')
trocr_reference


Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.2"
}

Config of the decoder: <class 'transformers.models.electra.modeling_electra.ElectraForCausalLM'> is overwritten by shared decoder config: ElectraConfig {
  "_name_or_path": "/project/lt200324-optmul/pluem/model/huggingface_electra-small-25000-no-grad-small",
  "add_cross_attention": true,
  "architectures": [
    "ElectraModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "embedding_size": 128,
  "hidden_act":

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linea

In [3]:
class VisionElectraModel(nn.Module):
    """OCR model: a frozen ViT encoder feeding an ELECTRA causal-LM decoder.

    The encoder, the encoder-to-decoder projection and the decoder are all taken
    from the pretrained ``openthaigpt/thai-trocr`` ``VisionEncoderDecoderModel``.
    The ViT encoder is frozen; the projection and the whole decoder are trainable.

    Args:
        tokenizer: tokenizer whose pad/eos/bos ids and vocabulary size are copied
            into the decoder config.
        device: device the pretrained modules and the loss module are moved to.
    """

    def __init__(self, tokenizer, device='cpu'):
        super().__init__()

        # Load the pretrained checkpoint once and reuse all three parts of it.
        base_model = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr').to(device)
        self.encoder = base_model.encoder
        self.enc_to_dec_proj = base_model.enc_to_dec_proj
        self.decoder = base_model.decoder
        self.tokenizer = tokenizer

        # Align the decoder's special-token ids and vocabulary with the tokenizer.
        self.decoder.config.pad_token_id = tokenizer.pad_token_id
        self.decoder.config.eos_token_id = tokenizer.eos_token_id
        self.decoder.config.bos_token_id = tokenizer.bos_token_id
        self.decoder.resize_token_embeddings(len(tokenizer.get_vocab()))
        self.decoder.config.max_length = 20

        # Read after the resize so it reflects the tokenizer's vocabulary size.
        self.vocab_size = self.decoder.config.vocab_size

        # Freeze the ViT encoder. The projection and every decoder parameter
        # (including cross-attention) stay trainable.
        for param in self.encoder.parameters():
            param.requires_grad = False
        for param in self.enc_to_dec_proj.parameters():
            param.requires_grad = True
        for param in self.decoder.parameters():
            param.requires_grad = True

        self.loss_fnc = nn.CrossEntropyLoss(ignore_index=-100).to(device)

    def shift_tokens_right(self, input_ids, pad_token_id, bos_token_id):
        """Build teacher-forcing decoder inputs from ``labels``.

        Shifts ``input_ids`` one column to the right and puts ``bos_token_id`` in
        column 0. Any ``-100`` (loss-ignore marker) carried into the result is
        replaced with ``pad_token_id``, so every value is a valid embedding index.

        Raises:
            ValueError: if ``pad_token_id`` or ``bos_token_id`` is ``None``.
        """
        if pad_token_id is None:
            raise ValueError("pad_token_id must be set to build decoder inputs.")
        if bos_token_id is None:
            raise ValueError("bos_token_id must be set to build decoder inputs.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1]
        shifted_input_ids[:, 0] = bos_token_id
        shifted_input_ids[shifted_input_ids == -100] = pad_token_id
        return shifted_input_ids

    def forward(self, pixel_values, labels, decoder_attention_mask):
        """Teacher-forced forward pass used for training.

        Args:
            pixel_values: image batch for the ViT encoder.
            labels: target token ids; ``-100`` marks positions ignored by the loss.
            decoder_attention_mask: attention mask for the decoder input ids.

        Returns:
            ``(outputs, loss)``: the raw decoder output, and the cross-entropy
            between ``outputs.logits`` and ``labels``.
        """
        # The ViT encoder is frozen, so its autograd graph is not needed...
        with torch.no_grad():
            patch_embeddings = self.encoder(pixel_values).last_hidden_state
        # ...but enc_to_dec_proj is trainable. Running it under no_grad (as before)
        # meant its weights never received gradients.
        projected_features = self.enc_to_dec_proj(patch_embeddings)

        decoder_input_ids = self.shift_tokens_right(
            labels, self.decoder.config.pad_token_id, self.decoder.config.bos_token_id
        )

        # NOTE(review): passing ``labels`` lets the HF causal-LM decoder compute its
        # own ``outputs.loss``. HF's ElectraForCausalLM shifts labels internally for
        # next-token prediction (confirm for the installed version), so that value is
        # misaligned with ``decoder_input_ids``. Use the ``loss`` returned below.
        outputs = self.decoder(
            input_ids=decoder_input_ids,
            labels=labels,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=projected_features,
            encoder_attention_mask=None,
        )

        logits = outputs.logits
        # reshape (rather than view) also accepts non-contiguous label tensors.
        loss = self.loss_fnc(logits.reshape(-1, self.vocab_size), labels.reshape(-1))

        return outputs, loss

    def predict(self, pixel_values, max_length=20, device=None):
        """Greedy autoregressive decoding for a batch of images.

        Args:
            pixel_values: image batch (any batch size).
            max_length: maximum number of tokens generated after BOS.
            device: device for the generated ids. Defaults to
                ``pixel_values.device``, so CPU and GPU inputs both work.

        Returns:
            LongTensor of shape ``(batch, 1 + generated)`` whose first column is
            BOS. Decoding stops once every row has emitted EOS or ``max_length``
            tokens were produced. Rows that finished earlier are padded with
            ``pad_token_id``.

        Note:
            Switches the module to eval mode and leaves it there.
        """
        if device is None:
            device = pixel_values.device

        bos_id = self.decoder.config.bos_token_id
        eos_id = self.decoder.config.eos_token_id
        pad_id = self.decoder.config.pad_token_id

        self.eval()
        with torch.no_grad():
            patch_embeddings = self.encoder(pixel_values).last_hidden_state
            projected_features = self.enc_to_dec_proj(patch_embeddings)

            batch_size = projected_features.shape[0]
            input_ids = torch.full((batch_size, 1), bos_id, dtype=torch.long, device=device)
            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

            for _ in range(max_length):
                # The full prefix is re-fed each step and no past_key_values are passed
                # back, so a KV cache would never be reused; don't build one.
                outputs = self.decoder(
                    input_ids=input_ids,
                    encoder_hidden_states=projected_features,
                    encoder_attention_mask=None,
                    use_cache=False,
                )

                # Greedy choice from the logits at the last position.
                next_token_ids = outputs.logits[:, -1, :].argmax(dim=-1)
                if pad_id is not None:
                    # Rows that already produced EOS only receive padding.
                    next_token_ids = next_token_ids.masked_fill(finished, pad_id)

                input_ids = torch.cat([input_ids, next_token_ids.unsqueeze(-1)], dim=1)
                finished |= next_token_ids == eos_id
                if finished.all():
                    break

        return input_ids

#### Load Model

In [6]:
import os

CHECKPOINT_PATH = "/mnt/e/ALPR-CEPJ1-Experiments-logs/projects/best-w/ex4-e149l0.3560c0.2044w0.2971.pth"

# Sanity check: confirm the fine-tuned checkpoint is reachable before loading it.
os.path.exists(CHECKPOINT_PATH)

True

In [7]:
# Fine-tuned weights for VisionElectraModel (absolute path on the experiment machine).
CHECKPOINT_PATH = "/mnt/e/ALPR-CEPJ1-Experiments-logs/projects/best-w/ex4-e149l0.3560c0.2044w0.2971.pth"

tokenizer = CamembertTokenizer.from_pretrained("openthaigpt/thai-trocr")
model = VisionElectraModel(tokenizer=tokenizer, device='cpu')

# weights_only=True limits unpickling to tensors and plain containers, so a tampered
# .pth file cannot execute arbitrary code. A saved state_dict loads unchanged.
state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(state_dict)

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.2"
}

Config of the decoder: <class 'transformers.models.electra.modeling_electra.ElectraForCausalLM'> is overwritten by shared decoder config: ElectraConfig {
  "_name_or_path": "/project/lt200324-optmul/pluem/model/huggingface_electra-small-25000-no-grad-small",
  "add_cross_attention": true,
  "architectures": [
    "ElectraModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "embedding_size": 128,
  "hidden_act":

<All keys matched successfully>

In [1]:
import transformers

# Record which transformers release this kernel actually imported.
installed_transformers_version = transformers.__version__
installed_transformers_version

  from .autonotebook import tqdm as notebook_tqdm


'4.46.2'

In [7]:
# Pin transformers to 4.46.2, the version the model configs above report.
# As the install output notes, restart the kernel afterwards so the pinned
# version is the one that gets imported.
%pip install transformers==4.46.2

Collecting transformers==4.46.2
  Downloading transformers-4.46.2-py3-none-any.whl.metadata (44 kB)
Downloading transformers-4.46.2-py3-none-any.whl (10.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.46.3
    Uninstalling transformers-4.46.3:
      Successfully uninstalled transformers-4.46.3
Successfully installed transformers-4.46.2
Note: you may need to restart the kernel to use updated packages.


### Inference

In [None]:
# Example greedy inference on one cropped image (currently disabled).
# NOTE(review): `feature_extractor`, `cropped_character_images` and `plt` are not
# defined in the cells visible here; they presumably come from another notebook.
# Define them before uncommenting. `model` and `tokenizer` come from the load cell above.
# pixel_values = feature_extractor(cropped_character_images[1].convert('RGB'), return_tensors="pt").pixel_values

# Undo the (x - 0.5) / 0.5 normalisation so the image can be displayed.
# std = mean= [0.5, 0.5, 0.5]
# plt.imshow(pixel_values.squeeze(dim=0).permute(1, 2, 0) * torch.tensor(std) + torch.tensor(mean))

# predict_ids = model.predict(pixel_values, device='cpu')
# predict_words = tokenizer.batch_decode(predict_ids, skip_special_tokens=True)
# predict_words