In [1]:
import lmdb
import json
from tqdm.notebook import tqdm, trange
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from utils import round_float, is_nan
from collections import Counter


In [2]:
BOS_TOKEN = "<|BOS|>"
START = "<start>"
END = "<end>"


SEPARATOR_TOKENS = [
    BOS_TOKEN,
    START,
    END
]

LINE_TOKEN = "<line>"
VERTICAL_BAR_TOKEN = "<vertical_bar>"
HORIZONTAL_BAR_TOKEN = "<horizontal_bar>"
SCATTER_TOKEN = "<scatter>"
DOT_TOKEN = "<dot>"

CHART_TYPE_TOKENS = [
    LINE_TOKEN,
    VERTICAL_BAR_TOKEN,
    HORIZONTAL_BAR_TOKEN,
    SCATTER_TOKEN,
    DOT_TOKEN,
]

new_tokens = SEPARATOR_TOKENS + CHART_TYPE_TOKENS


In [3]:
def data_series_to_string(json_dict):
    """
    Args:
        json_dict (Dict[str, Any]): ターゲットのdict
    Returns:
        gt_string (str): 入力となるプロンプト
    """
    all_x, all_y = [], []

    for d in json_dict['data-series']:
        x = d["x"]
        y = d["y"]

        x = round_float(x)
        y = round_float(y)

        # Ignore nan values
        if is_nan(x) or is_nan(y):
            continue

        all_x.append(x)
        all_y.append(y)

    chart_type = f"<{json_dict['chart-type']}>"
    data_str = \
        START + \
        '|'.join([f'{x}|{y}' for x, y in zip(all_x, all_y)]) \
        + END

    gt_string = BOS_TOKEN + chart_type + data_str

    return gt_string

In [4]:
processor = AutoProcessor.from_pretrained('google/matcha-base')

In [5]:
processor.tokenizer.add_tokens(new_tokens)

8

### check token length

In [13]:
token_lengthes = []
chart_types = []
sources = []

env = lmdb.open('../../../data/data0004/lmdb', max_readers=32,
                readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    n_samples = int(txn.get('num-samples'.encode()))
    for idx in trange(n_samples):
        # load json
        label_key = f'label-{str(idx+1).zfill(8)}'.encode()
        label = txn.get(label_key).decode('utf-8')
        json_dict = json.loads(label)
        data_str = data_series_to_string(json_dict)
        token_ids = processor.tokenizer.encode(data_str)
        token_lengthes.append(len(token_ids))
        chart_types.append(json_dict['chart-type'])
        sources.append(json_dict['source'])

  0%|          | 0/60578 [00:00<?, ?it/s]

In [20]:
longs = [len > 512 for len in token_lengthes]
long_source = [source for i, source in enumerate(sources) if longs[i]]

In [21]:
Counter(long_source)

Counter({'generated': 224, 'extracted': 17})

### check unk

In [6]:
unk_tokens = []
unk_token_id = processor.tokenizer.unk_token_id

In [52]:
env = lmdb.open('../../../data/data0004/lmdb', max_readers=32,
                readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    n_samples = int(txn.get('num-samples'.encode()))
    for idx in trange(n_samples):
        # load json
        label_key = f'label-{str(idx+1).zfill(8)}'.encode()
        label = txn.get(label_key).decode('utf-8')
        json_dict = json.loads(label)
        data_str = data_series_to_string(json_dict)
        token_ids = processor.tokenizer.encode(data_str)
        tokens = processor.tokenizer.tokenize(data_str)
        for token, token_id in zip(tokens, token_ids):
            if token_id == unk_token_id:
                unk_tokens.append(token)
    unk_counter = Counter(unk_tokens)

  0%|          | 0/60578 [00:00<?, ?it/s]

In [55]:
unk_counter

Counter({'<unk>': 157, '\n': 284, '\n\n': 28, 'ދ': 2})

### decide separate character

In [82]:
processor.tokenizer.encode(';')

[273, 324, 1]

In [83]:
check_token_id = 324 # processor.tokenizer.encode(',')
check_tokens = []

with env.begin(write=False) as txn:
    n_samples = int(txn.get('num-samples'.encode()))
    for idx in trange(n_samples):
        # load json
        label_key = f'label-{str(idx+1).zfill(8)}'.encode()
        label = txn.get(label_key).decode('utf-8')
        json_dict = json.loads(label)
        data_str = data_series_to_string(json_dict)

        token_ids = processor.tokenizer.encode(data_str)
        tokens = processor.tokenizer.tokenize(data_str)
        for token, token_id in zip(tokens, token_ids):
            if token_id == check_token_id:
                check_tokens.append(token)
                break
    check_counter = Counter(check_tokens)

  0%|          | 0/60578 [00:00<?, ?it/s]

In [84]:
check_counter

Counter()

In [68]:
processor.tokenizer('').input_ids

[273, 285, 275, 289, 1]

In [77]:
processor.tokenizer.decode(789)

'|'

In [22]:
s = 

In [23]:
s.split('|')

["Would grant Gov't/Dail too much control"]

In [25]:
import re

a


In [28]:
re.search('|', "Would grant Gov't/Dail too much control")

<re.Match object; span=(0, 0), match=''>

In [29]:
bool(re.search('|', '324'))

True

In [32]:
'|' in '324|'

True