In [19]:
import numpy as np
import lmdb
import json
from tqdm.notebook import tqdm, trange
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from utils import round_float, is_nan
from collections import Counter
import torch
from torch import nn
from torch.utils.data import Dataset
from PIL import Image
import six

In [54]:
BOS_TOKEN = "<|BOS|>"
X_START = "<|x_start|>"
X_END = "<|x_end|>"
Y_START = "<|y_start|>"
Y_END = "<|y_end|>"


SEPARATOR_TOKENS = [
    BOS_TOKEN,
    X_START,
    X_END,
    Y_START,
    Y_END,
]

LINE_TOKEN = "<line>"
VERTICAL_BAR_TOKEN = "<vertical_bar>"
HORIZONTAL_BAR_TOKEN = "<horizontal_bar>"
SCATTER_TOKEN = "<scatter>"
DOT_TOKEN = "<dot>"

CHART_TYPE_TOKENS = [
    LINE_TOKEN,
    VERTICAL_BAR_TOKEN,
    HORIZONTAL_BAR_TOKEN,
    SCATTER_TOKEN,
    DOT_TOKEN,
]

new_tokens = SEPARATOR_TOKENS + CHART_TYPE_TOKENS

max_patches = 1024
max_length = 1024


In [57]:
processor = AutoProcessor.from_pretrained('../../../outputs/6_17')
processor.image_processor.size = {
    "height": 560,
    "width": 560,
}
processor.image_processor.is_vqa = False
processor.tokenizer.add_tokens(new_tokens)

model = Pix2StructForConditionalGeneration.from_pretrained('../../../outputs/6_17')
model.decoder.resize_token_embeddings(len(processor.tokenizer))
model.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([BOS_TOKEN])[0]
model.config.text_config.is_decoder = True

In [58]:
model.save_pretrained('../../../outputs/6_17')
processor.save_pretrained('../../../outputs/6_17')

In [59]:
# env = lmdb.open('../../../data/data0004/lmdb', max_readers=32,
#                 readonly=True, lock=False, readahead=False, meminit=False)
# with env.begin(write=False) as txn:
#     n_samples = int(txn.get('num-samples'.encode()))
#     for idx in trange(n_samples):
#         # load json
#         label_key = f'label-{str(idx+1).zfill(8)}'.encode()
#         label = txn.get(label_key).decode('utf-8')
#         json_dict = json.loads(label)
#         data_str = data_series_to_string(json_dict)
#         token_ids = processor.tokenizer.encode(data_str)
#         tokens = processor.tokenizer.tokenize(data_str)
#         for token, token_id in zip(tokens, token_ids):
#             if token_id == unk_token_id:
#                 unk_tokens.append(token)
#     unk_counter = Counter(unk_tokens)

In [89]:
class MgaDataset(Dataset):
    def __init__(self, lmdb_dir, processor):
        self.processor = processor
        self.env = lmdb.open(str(lmdb_dir), max_readers=32,
                             readonly=True, lock=False, readahead=False, meminit=False)

    def _json_dict_to_gt_string(self, json_dict) -> str:
        """
        Args:
            json_dict (Dict[str, Any]): ターゲットのdict
        Returns:
            gt_string (str): 入力となるプロンプト
        """
        all_x, all_y = [], []

        for d in json_dict['data-series']:
            x = d["x"]
            y = d["y"]

            x = round_float(x)
            y = round_float(y)

            # Ignore nan values
            if is_nan(x) or is_nan(y):
                continue

            all_x.append(x)
            all_y.append(y)

        chart_type = f"<{json_dict['chart-type']}>"
        x_str = X_START + ";".join(list(map(str, all_x))) + X_END
        y_str = Y_START + ";".join(list(map(str, all_y))) + Y_END

        gt_string = BOS_TOKEN + chart_type + x_str + y_str

        return gt_string, list(map(str, all_x)), list(map(str, all_y))

    def __len__(self):
        with self.env.begin(write=False) as txn:
            n_samples = txn.get('num-samples'.encode())
        return n_samples

    def __getitem__(self, idx):
        """
        lmdbからidに一致したimageとlabelを取り出す

        image
            - byteをdecodeしてPIL.Image -> numpyにする

        label
            - byteからjson.loadsでdictにする
                keys: ['source', 'chart-type', 'plot-bb', 'text',
                    'axes', 'data-series', 'id', 'key_point']
            - 'data-series'から正解となるpromptを生成

        Returns:
            samples (Dict[str, Union[torch.Tensor, List[int], str]])
                pixel_values (torch.Tensor): 画像
                input_ids (List[int]): token idのリスト
                ids (str)
        """
        with self.env.begin(write=False) as txn:
            # load image
            img_key = f'image-{str(idx+1).zfill(8)}'.encode()
            imgbuf = txn.get(img_key)

            # load json
            label_key = f'label-{str(idx+1).zfill(8)}'.encode()
            label = txn.get(label_key).decode('utf-8')

        # label: ['source', 'chart-type', 'plot-bb', 'text', 'axes', 'data-series', 'id', 'key_point']
        json_dict = json.loads(label)

        # image
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        image_arr = np.array(Image.open(buf).convert('RGB'))
        h, w, _ = image_arr.shape
        encoding = processor(
            images=image_arr,
            random_padding=True,
            add_special_tokens=True,
            max_patches=max_patches,
            return_tensors='pt'
        )
        # encoding = {k: v[0].squeeze() for k, v in encoding.items()}
        # encoding = {k: v for k, v in encoding.items()}

        gt_string, _, _ = self._json_dict_to_gt_string(json_dict)

        text_inputs = processor(
            text=gt_string,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=max_length
        ).input_ids

        encoding['labels'] = text_inputs
        encoding['source'] = 0 if json_dict['source'] == 'generaeted' else 1
        encoding['id'] = json_dict['id']
        return encoding

In [90]:
ds = MgaDataset('../../../data/data0004/lmdb', processor)

In [91]:
data = ds.__getitem__(0)
labels = data['labels']
output = model(
    flattened_patches=data['flattened_patches'], #.unsqueeze(0),
    attention_mask=data['attention_mask'], # .unsqueeze(0)
    labels=labels
)

In [92]:
data['source']

1

In [93]:
output.logits.shape

torch.Size([1, 1024, 50354])

In [94]:
labels.shape

torch.Size([1, 1024])

In [95]:
labels

tensor([[50344, 50352, 50345,  ...,     0,     0,     0]])

In [97]:
processor.tokenizer.decode(labels[0, 2])

'<|x_start|>'

In [98]:
# output.

In [99]:
loss_fn = nn.CrossEntropyLoss()

In [100]:
loss = loss_fn(output.logits.reshape(-1, model.decoder.config.vocab_size), labels.reshape(-1))
chart_type_loss = loss_fn(output.logits.reshape(-1, model.decoder.config.vocab_size)[1:2, :], labels.reshape(-1)[1:2])

In [101]:
output.logits.reshape(-1, model.decoder.config.vocab_size).shape

torch.Size([1024, 50354])

In [102]:
chart_type_loss

tensor(4.6274, grad_fn=<NllLossBackward0>)

In [105]:
processor.tokenizer.decode(labels.reshape(-1)[0:7])

'<|BOS|><scatter><|x_start|> 1.0'

In [107]:
output.logits.shape

torch.Size([1, 1024, 50354])

In [285]:
class CustomLoss(nn.Module):
    def __init__(self, extracted_weight=100.):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.extracted_weight = extracted_weight
    
    def forward(self, input, target, source):
        '''
            input: (bs, length, vocab_size)
            target: (bs, length)
            source: (bs)
        '''

        bs, l, vs = input.shape
        input = input.reshape(-1, vs)
        target = target.reshape(-1)
        source = torch.tile(source, (1, l)).reshape(-1)
        weight = self.extracted_weight * source + (1. - source)

        ls = self.log_softmax(input)
        loss_per_bs = -1 * ls.index_select(-1, target).diag() # (bs * len)
        return torch.mean(loss_per_bs * weight)

In [286]:
# pred: (bs, len, voc_size)
# target: (bs, len)

# source: (bs, ) => (bs, len)

# loss_input: (bs * len, voc_size)
# loss_target: (bs * len)

# loss_per_bs: (bs)

# bs = 2, len=3, voc_size=4

In [287]:
pred = torch.randn((2, 3, 4)).float()
target = torch.tensor([
    [1, 0, 2],
    [0, 0, 3]
]).long()
source = torch.tensor([1, 1]).long()

In [288]:
# pred = torch.tensor([[-100, -0.2, 0.5], [0.8, -0.2, 0.5]]).float()
# target = torch.tensor([0, 2]).long()

loss_fn = nn.CrossEntropyLoss()
print('ce loss', loss_fn(pred.reshape(-1, 4), target.reshape(-1)))
custom_loss_fn = CustomLoss()
print('custom loss', custom_loss_fn(pred, target, source))

ce loss tensor(1.5734)
custom loss tensor(157.3420)


In [162]:
pred.shape

torch.Size([1, 3])

tensor(1.7461)

tensor([-1.7461])

In [129]:
s = nn.LogSoftmax(dim=1)

In [130]:
i = torch.randn(2, 3)
s(i)

tensor([[-0.7632, -2.1829, -0.8648],
        [-1.5033, -1.8776, -0.4706]])

In [305]:
a = list(np.arange(100))

In [319]:
import random



[32, 41, 25, 30, 24, 14, 37, 60, 71, 78]