In [1]:
import os
import torch

from src.models import DonutSwinEncoder, NougatDecoder
from transformers import NougatProcessor, VisionEncoderDecoderConfig

ckpt = "/Users/rohan/3_Resources/ai_models/nougat-small"
config = VisionEncoderDecoderConfig.from_pretrained(ckpt)
processor = NougatProcessor.from_pretrained(ckpt)

encoder = DonutSwinEncoder(config.encoder)
decoder = NougatDecoder(
  embed_dim=config.decoder.d_model,
  num_layers=config.decoder.decoder_layers,
  vocab_size=config.decoder.vocab_size,
  scale_embedding=config.decoder.scale_embedding,
  num_heads=config.decoder.decoder_attention_heads,
  max_len=config.decoder.max_position_embeddings,
)

encoder.load_state_dict(torch.load(os.path.join(ckpt, 'encoder.bin')))
decoder.load_state_dict(torch.load(os.path.join(ckpt, 'decoder.bin')))

encoder.eval()
decoder.eval()

NougatDecoder(
  (decoder): MBartDecoder(
    (word_embeddings): Embedding(50000, 1024)
    (position_embeddings): Embedding(3586, 1024)
    (layers): ModuleList(
      (0-3): 4 x MBartLayer(
        (self_attn): MBartSelfAttention(
          (key): Linear(in_features=1024, out_features=1024, bias=True)
          (value): Linear(in_features=1024, out_features=1024, bias=True)
          (query): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (ln1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (cross_attn): MBartSelfAttention(
          (key): Linear(in_features=1024, out_features=1024, bias=True)
          (value): Linear(in_features=1024, out_features=1024, bias=True)
          (query): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (ln2): LayerNorm((1024,), eps=1e-

In [2]:
from PIL import Image


In [3]:
image = Image.open('tmp.png').convert('RGB')
pixel_values = processor([image, image], return_tensors='pt').pixel_values
encoder_hidden_state = encoder(pixel_values)
encoder_hidden_state.shape

torch.Size([2, 588, 1024])

In [4]:
# generation
import torch

bz = encoder_hidden_state.shape[0]
input_ids = [[0] for _ in range(bz)]  # start token
output_ids = [[0] for _ in range(bz)]  # will be appending the generated ids to this

n_layers = config.decoder.decoder_layers
embed_dim = config.decoder.d_model

ks_cache = torch.empty(n_layers, bz, 0, embed_dim)
vs_cache = torch.empty(n_layers, bz, 0, embed_dim)
kc_cache = torch.empty(n_layers, bz, 0, embed_dim)
vc_cache = torch.empty(n_layers, bz, 0, embed_dim)

cache = (ks_cache, vs_cache, kc_cache, vc_cache)

max_len = 10
inp = torch.tensor(input_ids)
while True and max_len > 0:
  n_inp_tokens = len(input_ids[0])
  cross_attention_mask = torch.zeros(1, 1, n_inp_tokens, encoder_hidden_state.shape[1])  # probably will broadcast
  self_attention_mask = torch.zeros(1, 1, n_inp_tokens, n_inp_tokens)
  
  with torch.no_grad():
#     inp = torch.tensor(input_ids)
    kv_logits, cache = decoder(
      inp,
      self_attention_mask,
      encoder_hidden_state,
      cross_attention_mask,
      *cache
    )
    
  next_token_ids = kv_logits[:, -1, :].argmax(dim=-1)
  break_flag = True
  for i, x in enumerate(next_token_ids):
    if x != 2 or x !=1 : break_flag=False 
    output_ids[i].append(x)
  inp = next_token_ids.unsqueeze(1)
#   print(inp.shape)
  max_len -= 1
  print(max_len)
  if break_flag: break

torch.Size([2, 1])
9
torch.Size([2, 1])
8
torch.Size([2, 1])
7
torch.Size([2, 1])
6
torch.Size([2, 1])
5
torch.Size([2, 1])
4
torch.Size([2, 1])
3
torch.Size([2, 1])
2
torch.Size([2, 1])
1
torch.Size([2, 1])
0


In [20]:
pred = processor.tokenizer.decode(output_ids[0])
print(pred)

<s>can be controlled with text, such as style or


In [23]:
all(next_token_ids == 457) 

True