In [1]:
import collections
from pathlib import Path

import h5py
import numpy as np
from tqdm.notebook import tqdm
import transformers

In [2]:
# Tokenizer whose vocabulary is used below to decode stored token ids and to look up the EOS id.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="distilgpt2",
)

In [3]:
data_root = Path("../data/basic_arithmetic/80_3_6_200000/")


def load_split(path):
    """Read every top-level dataset of an HDF5 file fully into memory.

    Args:
        path: path to the ``.h5`` file.

    Returns:
        dict mapping dataset name -> in-memory array (the result of ``dataset[:]``).
    """
    # Open explicitly read-only: h5py < 3.0 defaults to mode "a", which can create or modify the file.
    with h5py.File(path, "r") as fin:
        return {name: dataset[:] for name, dataset in fin.items()}


training = load_split(data_root / "training.h5")
validation = load_split(data_root / "validation.h5")

In [4]:
# List the dataset names available in the training split.
training_keys = training.keys()
print(training_keys)

dict_keys(['input', 'input_and_scratchpad_with_value', 'input_and_scratchpad_with_value_attention_mask', 'input_and_scratchpad_with_value_text', 'input_attention_mask', 'input_text', 'scratchpad', 'scratchpad_attention_mask', 'scratchpad_text', 'scratchpad_with_value', 'scratchpad_with_value_attention_mask', 'scratchpad_with_value_text', 'value', 'value_attention_mask', 'value_text'])


In [5]:
# Labels come from the "scratchpad_with_value" dataset, model inputs from "input".
tr_labels, tr_input_ids = training["scratchpad_with_value"], training["input"]
va_labels, va_input_ids = validation["scratchpad_with_value"], validation["input"]

# Show (num_examples, padded_length) for the training split, then the validation split.
for split_input_ids in (tr_input_ids, va_input_ids):
    print(split_input_ids.shape)

(499900, 46)
(499603, 47)


In [6]:
def average_length_labels(labels, eos_token_id=None, ignore_index=-100):
    """Mean per-row count of tokens that are neither padding nor EOS.

    Args:
        labels: array of token ids with at least 2 dimensions; counting is done along axis 1
            (one example per row).
        eos_token_id: token id excluded from the count. Defaults to ``tokenizer.eos_token_id``
            (the module-level tokenizer), matching the original behaviour.
        ignore_index: padding/ignore sentinel excluded from the count (default ``-100``).

    Returns:
        The mean, over rows, of the number of counted tokens. Returns NaN for an input with
        zero rows instead of triggering numpy's "mean of empty slice" warning.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    labels = np.asarray(labels)
    if labels.shape[0] == 0:
        return float("nan")
    is_real_token = (labels != ignore_index) & (labels != eos_token_id)
    return np.sum(is_real_token, axis=1).mean()

def _max_paren_depth(text):
    """Return the deepest '(' nesting reached while scanning `text` left to right."""
    depth = deepest = 0
    for char in text:
        if char == "(":
            depth += 1
            deepest = max(deepest, depth)
        elif char == ")":
            depth -= 1  # unbalanced text is not validated; depth may go negative
    return deepest


def compute_level(input_ids, labels=None):
    """Group examples by parenthesis nesting level and summarise each group.

    Each row of `input_ids` is decoded with the module-level `tokenizer`. Its "level" is the
    maximum parenthesis nesting depth found in the decoded text.

    Args:
        input_ids: array of token ids, one example per row (assumed 2-D — TODO confirm for
            other callers).
        labels: optional array aligned row-for-row with `input_ids`. When given,
            `lengths_per_level` averages the lengths of these rows (e.g. the scratchpad
            labels) instead of the input rows. Defaults to None, which measures
            `input_ids` as before.

    Returns:
        (sample_qty_per_level, lengths_per_level): two dicts keyed by level in ascending
        order. The first maps level -> number of examples. The second maps
        level -> `average_length_labels` of that level's rows.

    Raises:
        ValueError: if `labels` does not have the same number of rows as `input_ids`.
    """
    input_ids = np.asarray(input_ids)
    # NOTE(review): without `labels`, this measures input lengths even though the helper is
    # named for labels; pass labels=... to get label lengths per level.
    measured = input_ids if labels is None else np.asarray(labels)
    if len(measured) != len(input_ids):
        raise ValueError(
            f"labels has {len(measured)} rows but input_ids has {len(input_ids)}"
        )

    # Decode and score one row at a time so the full list of decoded strings is never held.
    levels = np.fromiter(
        (
            _max_paren_depth(tokenizer.decode(ids))
            for ids in tqdm(input_ids, desc="counting levels")
        ),
        dtype=np.int64,
        count=len(input_ids),
    )

    # np.unique returns levels sorted ascending, so both dicts come out ordered by level.
    unique_levels, counts = np.unique(levels, return_counts=True)
    sample_qty_per_level = {
        int(level): int(count) for level, count in zip(unique_levels, counts)
    }
    lengths_per_level = {
        int(level): average_length_labels(measured[levels == level])
        for level in unique_levels
    }
    return sample_qty_per_level, lengths_per_level

In [7]:
# Average number of non-padding, non-EOS label tokens: training first, then validation.
for split_labels in (tr_labels, va_labels):
    print(average_length_labels(split_labels))

36.834062812562514
36.820753678420665


In [8]:
sample_qty_per_level_tr, lengths_per_level_tr = compute_level(tr_input_ids)
sample_qty_per_level_va, lengths_per_level_va = compute_level(va_input_ids)

# Same "name = repr(value)" lines that the f-string `=` specifier would produce.
for stat_name, stat_value in (
    ("sample_qty_per_level_tr", sample_qty_per_level_tr),
    ("sample_qty_per_level_va", sample_qty_per_level_va),
    ("lengths_per_level_tr", lengths_per_level_tr),
    ("lengths_per_level_va", lengths_per_level_va),
):
    print(f"{stat_name} = {stat_value!r}")

decoding:   0%|          | 0/499900 [00:00<?, ?it/s]

counting levels:   0%|          | 0/499900 [00:00<?, ?it/s]