In [56]:
import tiktoken
import torch
import torch.nn.functional as F

# Load tiktoken's "gpt2" encoding; later cells reuse this `tokenizer` object.
tokenizer = tiktoken.get_encoding("gpt2")
# Sanity check: encode a single word and display the resulting token ids.
tokenizer.encode("hello")

[31373]

In [2]:
# Display the base encoding's regex pattern string (a private attribute); the
# same value is passed as `pat_str` when the extended encoding is built.
tokenizer._pat_str

"'(?:[sdmt]|ll|ve|re)| ?\\p{L}++| ?\\p{N}++| ?[^\\s\\p{L}\\p{N}]++|\\s++$|\\s+(?!\\S)|\\s"

In [3]:
# Build a copy of the GPT-2 encoding that also knows an "<image>" special token.
# NOTE(review): `n_vocab + 1` appears to leave id `n_vocab` unused (ids are
# 0-based) -- confirm the one-id gap is intended before sizing embedding tables.
special_tokens = {"<image>": tokenizer.n_vocab + 1}
tokenizer_modified = tiktoken.Encoding(
    name="gpt2_with_image",
    pat_str=tokenizer._pat_str,
    mergeable_ranks=tokenizer._mergeable_ranks,
    # Existing special tokens first, then the new one (new entries win on clashes).
    special_tokens=dict(tokenizer._special_tokens, **special_tokens),
)

In [77]:
def text_to_token_ids(texts, tokenizer, device="cpu", pad_token_id=50256):
    """Encode one string or a batch of strings into a 2-D tensor of token ids.

    Args:
        texts: A single string, or a list/tuple of strings.
        tokenizer: Object exposing ``encode(text, allowed_special=...)`` that
            returns a list of integer token ids.
        device: Device on which the returned tensor is created.
        pad_token_id: Id used to right-pad shorter sequences of a batch.
            Defaults to 50256, the value the batch path previously hard-coded.

    Returns:
        A ``torch.long`` tensor. For a single string the shape is
        ``(1, seq_len)``; for a batch it is ``(len(texts), max_seq_len)`` with
        shorter rows right-padded with ``pad_token_id``. An empty batch yields
        a tensor of shape ``(0, 0)``.
    """
    allowed_special = {"<|endoftext|>", "<image>"}

    def _encode(text):
        # dtype is pinned on purpose: torch.tensor([]) defaults to float32, so
        # an empty string would otherwise produce a float row and promote the
        # whole batch to float once the rows are combined.
        return torch.tensor(
            tokenizer.encode(text, allowed_special=allowed_special),
            dtype=torch.long,
            device=device,
        )

    if not isinstance(texts, (list, tuple)):
        return _encode(texts).unsqueeze(0)

    rows = [_encode(text) for text in texts]
    if not rows:
        # max() below would raise on an empty batch; return an empty batch instead.
        return torch.empty((0, 0), dtype=torch.long, device=device)

    max_len = max(row.numel() for row in rows)
    return torch.stack([
        F.pad(row, (0, max_len - row.numel()), value=pad_token_id)
        for row in rows
    ])
        

def token_ids_to_text(token_ids, tokenizer):
    """Decode a tensor of token ids back into text.

    A leading dimension of size 1 is dropped, the ids are moved to the CPU and
    converted to a plain Python list, and that list is handed to
    ``tokenizer.decode``.
    """
    id_list = token_ids.squeeze(0).cpu().tolist()
    return tokenizer.decode(id_list)
    
# Round-trip check: encode a string with the base tokenizer, then decode it
# again; the displayed result should match the input text.
encoded = text_to_token_ids("hello hi __hi h...", tokenizer)
token_ids_to_text(encoded, tokenizer)

'hello hi __hi h...'

In [5]:
# Vocabulary size reported by the extended encoding (includes special tokens).
vocab_size = tokenizer_modified.n_vocab
vocab_size

50259

In [66]:
# Token id of the "<image>" placeholder, kept as a (1, 1) tensor so it can be
# compared against batches of token ids via broadcasting.
image_token_id = text_to_token_ids("<image>", tokenizer_modified)
image_token_id

tensor([[50258]])

In [78]:
# Random stand-ins for real model tensors (no seed is set, so values change
# on every run; only the shapes matter for the cells below).
# presumably image features shaped (batch, image_tokens, dim) -- TODO confirm
sample = torch.rand(2, 273, 1280)
# NOTE(review): text_embeds has 768 positions per row while input_ids below has
# 10 tokens per row, so embedding positions do not line up with token
# positions -- confirm whether this placeholder shape is intended.
text_embeds = torch.rand(2, 768, 1280)

batch_size = sample.shape[0]
# Two prompts of different length; the shorter one is right-padded on encoding.
texs = ["Extract <image> all text from this document.", "hello"] 
input_ids = text_to_token_ids(texs, tokenizer_modified)
input_ids

tensor([[11627,   974,   220, 50258,   477,  2420,   422,   428,  3188,    13],
        [31373, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]])

In [38]:
# Boolean mask that is True wherever input_ids equals the "<image>" token id;
# the (1, 1) id tensor broadcasts across the whole batch.
image_token_mask = (image_token_id == input_ids)
image_token_mask.shape

torch.Size([2, 10])

In [39]:
# Display the full mask to see which positions hold the image token.
image_token_mask

tensor([[False, False, False,  True, False, False, False, False, False, False],
        [False, False, False,  True, False, False, False, False, False, False]])

In [43]:
# Find the index of the image token in one row of the batch.
b = 0
image_positions = torch.where(image_token_mask[b])[0]
# .item() only succeeds when the row contains exactly one image token.
img_pos = image_positions.squeeze().item()
img_pos

3

In [48]:
# Splice the image features into row b in place of the "<image>" position:
# keep the embeddings before it, skip the placeholder itself, keep the rest.
before = text_embeds[b, :img_pos]
after = text_embeds[b, img_pos+1:]

# Resulting length is (positions before) + (image positions) + (positions after).
merged = torch.cat((before, sample[b] ,after), dim = 0)
merged.shape

torch.Size([1040, 1280])

In [81]:
# Replace each row's "<image>" placeholder with that row's image features,
# then pad the merged sequences to a common length and stack them.
image_token_id = text_to_token_ids("<image>", tokenizer_modified)
texs = ["Extract <image> all text from this document.", "hello <image>"]
input_ids = text_to_token_ids(texs, tokenizer_modified)

# The mask does not depend on the row, so build it once outside the loop.
image_token_mask = (image_token_id == input_ids)

final_embeds = []
for batch in range(batch_size):
    image_positions = torch.where(image_token_mask[batch])[0]
    # The splice below assumes a single placeholder per row; fail with a clear
    # message instead of an opaque .item() error when that does not hold.
    if image_positions.numel() != 1:
        raise ValueError(
            f"row {batch}: expected exactly one <image> token, "
            f"found {image_positions.numel()}"
        )
    img_pos = image_positions.item()

    # Keep the embeddings on both sides of the placeholder, drop the placeholder.
    before = text_embeds[batch, :img_pos]
    after = text_embeds[batch, img_pos + 1:]

    merged = torch.cat((before, sample[batch], after), dim=0)
    final_embeds.append(merged)

max_len = max(e.shape[0] for e in final_embeds)

# Pad along the sequence dimension with zeros. These are embedding values, so
# filling with 50256 (a token id) would inject huge activations into the batch.
padded_embeds = torch.stack([
    F.pad(e, (0, 0, 0, max_len - e.shape[0]), value=0.0)
    for e in final_embeds
])

padded_embeds.shape

torch.Size([2, 1040, 1280])

In [62]:
# Right-pad every merged sequence to the longest one so they can be stacked.
max_len = max(e.shape[0] for e in final_embeds)

# Pad along the sequence dimension with zeros. These are embedding values, so
# filling with 50256 (a token id) would inject huge activations into the batch.
padded_embeds = torch.stack([
    F.pad(e, (0, 0, 0, max_len - e.shape[0]), value=0.0)
    for e in final_embeds
])

padded_embeds.shape

torch.Size([2, 1040, 1280])

In [82]:
# Attention mask over the padded batch: 1 for positions that belong to a
# merged sequence, 0 for the padding appended after it.
attention_mask = torch.zeros(batch_size, max_len)
for i, e in enumerate(final_embeds):
    # e.shape[0] is the un-padded length of row i.
    attention_mask[i, :e.shape[0]] = 1

In [94]:
# Count the positions in row 0 that are masked out (0 means nothing was padded).
# Tensor.sum() reduces in one call; the builtin sum() would loop over the
# elements one by one in Python.
(attention_mask[0] != 1.).sum()

tensor(0)