In [1]:
from PIL import Image
from transformers import NougatProcessor, VisionEncoderDecoderModel
import torch


In [2]:
import os

# Local Nougat checkpoint. Set the NOUGAT_CKPT environment variable to point
# elsewhere instead of editing this absolute path.
ckpt = os.environ.get("NOUGAT_CKPT", "/Users/rohan/3_Resources/ai_models/nougat-small")
processor = NougatProcessor.from_pretrained(ckpt)
model = VisionEncoderDecoderModel.from_pretrained(ckpt)

# Pinned to CPU on purpose: the parity checks below create their inputs and
# caches on CPU and compare against the re-implemented `encoder`/`decoder`,
# which are never moved off CPU. (The previous `"cpu" if cuda else "cpu"`
# conditional evaluated to "cpu" in both branches anyway.)
device = "cpu"
model.to(device)

VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0-1): 2 x DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
           

In [3]:
from transformers import VisionEncoderDecoderConfig

# Read the composite (encoder + decoder) configuration from the same checkpoint;
# its sub-configs supply the hyper-parameters for the re-implemented modules.
config = VisionEncoderDecoderConfig.from_pretrained(pretrained_model_name_or_path=ckpt)
config  # shown as the cell output

VisionEncoderDecoderConfig {
  "architectures": [
    "VisionEncoderDecoderModel"
  ],
  "decoder": {
    "_name_or_path": "",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "add_cross_attention": true,
    "add_final_layer_norm": true,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": 0.0,
    "cross_attention_hidden_size": null,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 4,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.1,
    "early_stopping": false,
    "encoder_attention_heads": 16,
    "encoder_ffn_dim": 4096,
    "encoder_layerdrop": 0.0,
    "encoder_layers": 12,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_p

In [4]:
from src.models import DonutSwinEncoder, NougatDecoder

# The re-implemented encoder consumes the encoder sub-config directly.
encoder = DonutSwinEncoder(config.encoder)

# The re-implemented decoder takes plain hyper-parameters, passed positionally
# in this exact order (NougatDecoder's parameter names are not visible here).
dec_cfg = config.decoder
decoder = NougatDecoder(
  dec_cfg.d_model,
  dec_cfg.decoder_layers,
  dec_cfg.decoder_attention_heads,
  dec_cfg.vocab_size,
  dec_cfg.max_position_embeddings,
  dec_cfg.scale_embedding,
)

In [6]:
# mapping encoder
# Copy the HF encoder weights into the re-implemented encoder by *position*.
# This assumes both modules register parameters/buffers in the same order;
# the numerical parity check further down is what actually validates it.
import warnings

verbose_mapping = False  # set True to print every key pairing
src_state = model.encoder.state_dict()  # built once, not once per key
dst_state = encoder.state_dict()

new_state_dict = {}
for k1, k2 in zip(src_state.keys(), dst_state.keys()):
  if verbose_mapping:
    print(f'Mapping {k1:80s} -> {k2:50s}')
  if src_state[k1].shape != dst_state[k2].shape:
    raise ValueError(
      f'Shape mismatch mapping {k1} {tuple(src_state[k1].shape)} '
      f'-> {k2} {tuple(dst_state[k2].shape)}'
    )
  new_state_dict[k2] = src_state[k1]

# zip() stops at the shorter side. Leftover target keys are caught by the strict
# load below, but leftover source keys would otherwise be dropped silently.
unmapped_keys = list(src_state.keys())[len(dst_state):]
if unmapped_keys:
  warnings.warn(f'{len(unmapped_keys)} source encoder keys were not mapped: {unmapped_keys}')

encoder.load_state_dict(new_state_dict)

<All keys matched successfully>

In [9]:
# mapping decoder
# Copy the HF decoder weights into the re-implemented decoder by *position*.
# This assumes both modules register parameters/buffers in the same order;
# the numerical parity check further down is what actually validates it.
import warnings

verbose_mapping = False  # set True to print every key pairing
src_state = model.decoder.state_dict()  # built once, not once per key
dst_state = decoder.state_dict()

new_state_dict = {}
for k1, k2 in zip(src_state.keys(), dst_state.keys()):
  if verbose_mapping:
    print(f'Mapping {k1:80s} -> {k2:50s}')
  if src_state[k1].shape != dst_state[k2].shape:
    raise ValueError(
      f'Shape mismatch mapping {k1} {tuple(src_state[k1].shape)} '
      f'-> {k2} {tuple(dst_state[k2].shape)}'
    )
  new_state_dict[k2] = src_state[k1]

# zip() stops at the shorter side. Leftover target keys are caught by the strict
# load below, but leftover source keys (e.g. a tied output projection that the
# re-implementation does not register separately) would be dropped silently.
unmapped_keys = list(src_state.keys())[len(dst_state):]
if unmapped_keys:
  warnings.warn(f'{len(unmapped_keys)} source decoder keys were not mapped: {unmapped_keys}')

decoder.load_state_dict(new_state_dict)

<All keys matched successfully>

In [8]:
# testing encoder
# The re-implemented encoder must reproduce the HF encoder's last hidden state.
# 896 x 672 is presumably the processor's (H, W) input size — TODO confirm
# against `processor`.
torch.manual_seed(0)  # reproducible random input
model.eval()
encoder.eval()  # a freshly constructed nn.Module starts in training mode
with torch.no_grad():
  inp = torch.randn(1, 3, 896, 672)
  ref_hidden = model.encoder(inp).last_hidden_state
  new_hidden = encoder(inp)
  print(ref_hidden)
  print(new_hidden)
  print('max abs diff:', (ref_hidden - new_hidden).abs().max().item())
  # Small tolerance: the two implementations may order float ops differently.
  torch.testing.assert_close(new_hidden, ref_hidden, rtol=1e-4, atol=1e-4)

tensor([[[-0.2550,  0.0571, -1.6363,  ...,  0.2944,  0.2292, -0.1080],
         [-0.2865,  0.0792, -0.1264,  ...,  0.3664, -0.0534,  0.2924],
         [ 0.5225, -0.0422,  1.4999,  ...,  0.0639, -0.1852, -0.1215],
         ...,
         [-0.0994,  0.1885,  3.9789,  ...,  0.4358,  0.1407,  0.8877],
         [-0.2664,  0.1803,  2.4243,  ...,  0.2370, -0.0108,  0.7536],
         [-0.4452,  0.0722, -4.8807,  ..., -0.1352,  0.0684, -0.2183]]])
tensor([[[-0.2550,  0.0571, -1.6363,  ...,  0.2944,  0.2292, -0.1080],
         [-0.2865,  0.0792, -0.1264,  ...,  0.3664, -0.0534,  0.2924],
         [ 0.5225, -0.0422,  1.4999,  ...,  0.0639, -0.1852, -0.1215],
         ...,
         [-0.0994,  0.1885,  3.9789,  ...,  0.4358,  0.1407,  0.8877],
         [-0.2664,  0.1803,  2.4243,  ...,  0.2370, -0.0108,  0.7536],
         [-0.4452,  0.0722, -4.8807,  ..., -0.1352,  0.0684, -0.2183]]])


In [11]:
# testing decoder
# Run the HF decoder and the re-implementation on identical random inputs and
# require their logits to agree.
torch.manual_seed(0)  # reproducible random encoder states
model.eval()
decoder.eval()
dec_config = model.decoder.config
tgt_len, src_len = 20, 300  # arbitrary test lengths
with torch.no_grad():
  input_ids = torch.ones(1, tgt_len, dtype=torch.long)
  encoder_hidden_states = torch.randn(1, src_len, dec_config.d_model)
  # All-zero masks. NOTE(review): the HF call below gets no explicit mask; the
  # logits only match if NougatDecoder handles causal masking itself — confirm
  # in src.models.
  self_attention_mask = torch.zeros(1, 1, tgt_len, tgt_len)
  cross_attention_mask = torch.zeros(1, 1, tgt_len, src_len) # [bsz, 1, tgt_seq_len, src_seq_len]

  # Four empty (length-0 along dim 2) caches, one slot per decoder layer.
  cache_shape = (dec_config.decoder_layers, 1, 0, dec_config.d_model)
  ks_cache, vs_cache, kc_cache, vc_cache = (
    torch.empty(cache_shape, device='cpu') for _ in range(4)
  )
  cache = (ks_cache, vs_cache, kc_cache, vc_cache)

  ref_logits = model.decoder(input_ids, encoder_hidden_states=encoder_hidden_states).logits
  print(ref_logits)
  print('='*80)
  logits, cache = decoder(input_ids, self_attention_mask, encoder_hidden_states, cross_attention_mask, *cache)
  print(logits)
  print('max abs diff:', (ref_logits - logits).abs().max().item())
  # Small tolerance: the two implementations may order float ops differently.
  torch.testing.assert_close(logits, ref_logits, rtol=1e-4, atol=1e-4)

tensor([[[-1.0451, 23.8787, 15.1577,  ..., -1.9512, -1.3023, -4.4212],
         [-0.9088, 24.4422, 15.1216,  ..., -2.2604, -1.4398, -4.3751],
         [-0.8573, 24.7345, 14.8880,  ..., -2.3121, -1.4550, -4.3376],
         ...,
         [-1.3059, 26.4981, 12.7546,  ..., -2.4608, -1.5620, -4.0418],
         [-1.3484, 26.4824, 12.6804,  ..., -2.4934, -1.5732, -4.0317],
         [-1.3796, 26.4870, 12.5975,  ..., -2.5306, -1.5822, -4.0219]]])
tensor([[[-1.0451, 23.8787, 15.1577,  ..., -1.9512, -1.3023, -4.4212],
         [-0.9088, 24.4422, 15.1216,  ..., -2.2604, -1.4398, -4.3751],
         [-0.8573, 24.7345, 14.8880,  ..., -2.3121, -1.4550, -4.3376],
         ...,
         [-1.3059, 26.4981, 12.7546,  ..., -2.4608, -1.5620, -4.0418],
         [-1.3484, 26.4824, 12.6804,  ..., -2.4934, -1.5732, -4.0317],
         [-1.3796, 26.4870, 12.5975,  ..., -2.5306, -1.5822, -4.0219]]])


In [12]:
import os

# Persist the re-implemented modules' weights alongside the HF checkpoint files,
# encoder first, then decoder.
for filename, module in (('encoder.bin', encoder), ('decoder.bin', decoder)):
  torch.save(module.state_dict(), os.path.join(ckpt, filename))