In [1]:
from train import get_pretraining_datasets
from magma import Magma
import torch

In [2]:
# Instantiate MAGMA on the CPU from its YAML config. For finetuning one might
# want to load the model via Magma.from_checkpoint(...) here instead.
model = Magma('configs/MAGMA_v1_layoutlmv3.yml', device=torch.device("cpu"))

# Expose the model's tokenizer, config and image transforms under short names
# for the cells below.
tokenizer = model.tokenizer
config = model.config
transforms = model.transforms

Loading OPT language model...
From facebook/galactica-6.7b


In [54]:
from transformers import AutoTokenizer, OPTForCausalLM

# Load the pretrained tokenizer for "facebook/galactica-6.7b".
# NOTE(review): this rebinds `tokenizer`, replacing the value taken from
# `model.tokenizer` in the model-construction cell. A later cell passes
# `tokenizer` to get_pretraining_datasets, so on a top-to-bottom run this is
# the object it would receive; confirm which tokenizer is intended there.
tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-6.7b")

In [50]:
# Print the token ids for the [START_REF] marker on its own and embedded in a
# short sentence, one encoding per line.
for sample_text in ('[START_REF]', 'my name is [START_REF] dog'):
    print(tokenizer.encode(sample_text))

[4]
[9444, 4014, 343, 243, 4, 7214]


In [53]:
# Show the tokenizer's repr (class, vocab size, special tokens) as the cell output.
tokenizer

PreTrainedTokenizer(name_or_path='facebook/opt-6.7b', vocab_size=50265, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True)})

In [55]:
# Register "<|image|>" as the cls token and point pad/eos/bos/unk at "</s>".
# The call is the last expression, so its return value is the cell output.
tokenizer.add_special_tokens(
    dict(
        cls_token="<|image|>",
        pad_token="</s>",
        eos_token="</s>",
        bos_token="</s>",
        unk_token="</s>",
    )
)

1

In [56]:
# Show the tokenizer's repr again to check its special tokens after the
# add_special_tokens call in the previous cell.
tokenizer

PreTrainedTokenizerFast(name_or_path='facebook/galactica-6.7b', vocab_size=50000, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '</s>', 'eos_token': '</s>', 'unk_token': '</s>', 'pad_token': '</s>', 'cls_token': '<|image|>'})

In [57]:
# Decode a hand-picked subset of the ids printed by the encode cell above
# (4 is the id that cell printed for '[START_REF]') and print the text.
print(tokenizer.decode([9444, 4014, 4, ]))

my name[START_REF]


In [3]:
# Build the train/eval pretraining datasets from whatever `config`,
# `tokenizer` and `transforms` are currently bound to in the notebook.
train_dataset, eval_dataset = get_pretraining_datasets(config, tokenizer, transforms)

loading dataset paths from /home/duan/magma/dataset/scicap/train: 333442it [00:01, 215052.36it/s]
loading dataset from /home/duan/magma/dataset/scicap/train: 100%|██████████| 333442/333442 [01:27<00:00, 3789.68it/s] 
loading dataset paths from /home/duan/magma/dataset/scicap/val: 41680it [00:00, 434322.19it/s]
loading dataset from /home/duan/magma/dataset/scicap/val: 100%|██████████| 41680/41680 [01:04<00:00, 647.44it/s]  


Loaded train dataset with 333442 samples
Loaded eval dataset with 41680 samples


In [32]:
from transformers.tokenization_utils_base import BatchEncoding
from typing import List, Tuple, Generator

def collate_fn(batch_data: List[Tuple[torch.Tensor, torch.Tensor]], seq_len=2048):
    """Collate a list of ``(image, caption)`` samples into a single batch.

    The layout is chosen from the type of the first sample's image element:

    * ``BatchEncoding``: for every key of the first encoding, the values of
      all samples are concatenated along dim 0 and returned in a *new*
      ``BatchEncoding``. The input encodings are left untouched.
    * anything else: the image elements are concatenated with ``torch.cat``.

    In both cases each caption is sliced as ``caption[:, :seq_len]`` and the
    slices are concatenated along dim 0.

    Args:
        batch_data: non-empty list of ``(image, caption)`` pairs.
        seq_len: maximum number of positions kept along dim 1 of each caption.

    Returns:
        A ``(images, captions)`` tuple, where ``images`` is a ``BatchEncoding``
        or a tensor depending on the input layout.

    Raises:
        IndexError: if ``batch_data`` is empty.
    """
    first_image = batch_data[0][0]
    captions = torch.cat([sample[1][:, :seq_len] for sample in batch_data])

    if isinstance(first_image, BatchEncoding):
        encodings = [sample[0] for sample in batch_data]
        # Build a fresh mapping with one torch.cat per key. Accumulating into
        # the first sample's encoding would mutate an object owned by the
        # dataset, and pairwise concatenation re-copies the growing tensor on
        # every step.
        merged = {
            key: torch.cat([encoding[key] for encoding in encodings], dim=0)
            for key in first_image.keys()
        }
        return BatchEncoding(merged), captions

    # [(img1, caption1), (img2, caption2), ...] -> cat(img1, img2, ...)
    return torch.cat([sample[0] for sample in batch_data]), captions

from torch.utils.data import DataLoader
from magma.utils import cycle
from functools import partial
# Shuffled training loader with batches of 32; captions are truncated to the
# model's sequence length inside collate_fn.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=partial(collate_fn, seq_len=model.seq_len),
)


In [33]:
# Walk the training loader, printing len(caption) for each batch and gathering
# the bounding boxes. Extending the list iterates over img['bbox'], so every
# element along its first axis becomes a separate list entry.
bboxes = []
for img, caption in train_loader:
    print(len(caption))
    bboxes.extend(img['bbox'])

32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32


In [None]:
# Join every collected bbox tensor along the first axis (torch.cat's default dim).
all_bbox = torch.cat(bboxes)

In [None]:
# Maximum over the first axis of all collected boxes; the result carries both
# the max values and their indices and is shown as the cell output.
all_bbox.max(dim=0)