In [1]:
import collections
import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import rich
from tqdm import tqdm

import datagen
import our_tokenizer

In [2]:
DATA_DIR = Path.cwd() / "data" 
names = [x.name for x in DATA_DIR.iterdir()]
names.sort(key=lambda x: x.rsplit(".", 1)[-1])
print(names)


['349_6_6.json', '80_3_6.json', '80_3_6.json.pkl', '349_6_6.json.pkl']


In [3]:
FILE_NAME = "349_6_6.json.pkl"
assert FILE_NAME in names, names

In [4]:
def is_sorted(l):
    return all(l[i] <= l[i + 1] for i in range(len(l) - 1))

def plot_lengths(lengths, x_subdiv=1, verbose=False):
    if verbose:
        print("Checking if lengths are sorted...")
    assert is_sorted(lengths)
    if verbose:
        print("Plotting...")
    plt.figure(figsize=(10, 10))
    plt.xticks(np.arange(int(np.ceil((np.max(lengths) + 1) / x_subdiv))) * x_subdiv)
    plt.yticks(np.linspace(0, 1, 21))
    plt.plot(lengths, np.linspace(0, 1, len(lengths)))
    plt.show()

In [5]:
tokenizer = our_tokenizer.ArithmeticTokenizer()
data, config = datagen.load_dataset(None, DATA_DIR / FILE_NAME)

dict_keys([1, 2, 3, 4, 5, 6])
dict_keys([1, 2, 3, 4, 5, 6])
Building nodes


100%|██████████| 300/300 [00:00<00:00, 148069.10it/s]<?, ?it/s]
100%|██████████| 200000/200000 [00:03<00:00, 64708.11it/s]
100%|██████████| 200000/200000 [01:16<00:00, 2625.99it/s] 6,  1.55s/it]
100%|██████████| 200000/200000 [00:12<00:00, 16027.50it/s]7, 32.64s/it]
100%|██████████| 200000/200000 [01:44<00:00, 1918.32it/s]50, 25.12s/it]
100%|██████████| 200000/200000 [02:06<00:00, 1576.40it/s]52, 52.64s/it]
Building nodes for train: 100%|██████████| 6/6 [05:22<00:00, 53.81s/it]
100%|██████████| 300/300 [00:00<00:00, 158096.65it/s]?, ?it/s]
100%|██████████| 200000/200000 [00:03<00:00, 66190.73it/s]
100%|██████████| 200000/200000 [00:06<00:00, 30416.42it/s],  1.51s/it]
100%|██████████| 200000/200000 [00:12<00:00, 16299.55it/s],  3.62s/it]
100%|██████████| 200000/200000 [02:23<00:00, 1397.80it/s]3,  6.85s/it]
100%|██████████| 200000/200000 [00:30<00:00, 6552.25it/s]4, 54.22s/it]
Building nodes for eval: 100%|██████████| 6/6 [03:15<00:00, 32.58s/it]


In [6]:
for split, levels in data.items():
    rich.print(f"[bold blue]{split} - Num points:")
    for name, level in levels.items():
        print(f"\t{split} - < {name} >: {len(level)}")


levels = collections.defaultdict(list)

for split, levels_per_split in data.items():
    for level, level_data in levels_per_split.items():
        levels[level].extend(level_data)


def filter_by_total_length(nodes, limit):
    good_nodes = []
    for node in nodes:
        if not len(tokenizer(node.get_oracle_str()[0], return_tensors=None, no_eos=True)) <= limit:
            continue
        good_nodes.append(node)
    return good_nodes

def filter_by_value_length(nodes, limit):
    output = []
    for node in nodes:
        if not all(len(tokenizer(v.get_value(), return_tensors=None, no_eos=True)) <= limit for v in datagen.get_all_desc(node)):
            continue
        output.append(node)
    return output


value_lens = {}
oracle_lens = {}
percentiles = [.85, .9, .95, .99]
BY_VALUE_QUANTILE = 99
BY_ORACLE_QUANTILE = 95

for level, root_nodes in levels.items():
    all_nodes = list(datagen.multiple_get_all_desc(level_data))
    value_lens[level] = [len(tokenizer(v.get_value(), return_tensors=None, no_eos=True)) for v in all_nodes]
    value_lens[level].sort()
    oracle_lens[level] = [len(tokenizer(node.get_oracle_str()[0], return_tensors=None, no_eos=True)) for node in root_nodes]
    oracle_lens[level].sort()
    print(f"< {level} > value lens percentiles:")
    for p in percentiles:
        print(f"\t{p * 100}%: {int(np.percentile(value_lens[level], p * 100))}")
    print(f"< {level} > oracle lens percentiles :")
    for p in percentiles:
        print(f"\t{p * 100}%: {int(np.percentile(oracle_lens[level], p * 100))}")


    # By value and by oracle length``
    
    value_limit = math.ceil(np.percentile(value_lens[level], BY_VALUE_QUANTILE))
    oracle_limit = math.ceil(np.percentile(oracle_lens[level], BY_ORACLE_QUANTILE))
    
    by_oracle_length = filter_by_total_length(root_nodes, oracle_limit)
    doubly_filtered = filter_by_value_length(by_oracle_length, value_limit)
    print(f"By value and by oracle length: (value = {value_limit}, oracle = {oracle_limit})")
    print(f"\t< {level} > count: {len(doubly_filtered)} / {len(root_nodes)} {len(doubly_filtered)/len(root_nodes):0.2%}")

    by_value_length = filter_by_value_length(root_nodes, value_limit)
    print(f"By value only: {value_limit}")
    print(f"\t< {level} > count: {len(by_value_length)} / {len(root_nodes)} {len(by_value_length)/len(root_nodes):0.2%}")

    print(f"By oracle only: {oracle_limit}")
    print(f"\t< {level} > count: {len(by_oracle_length)} / {len(root_nodes)} {len(by_oracle_length)/len(root_nodes):0.2%}")


assert False

for name, level_data in levels.items():
    assert isinstance(name, (int, str))
    rich.print(f"[bold blue]Level < {name} >:")
    plot_lengths(value_lens)
    plot_lengths(oracle_lens, 6)

	train - < 1 >: 300
	train - < 2 >: 200000
	train - < 3 >: 200000
	train - < 4 >: 200000
	train - < 5 >: 200000
	train - < 6 >: 200000


	eval - < 1 >: 300
	eval - < 2 >: 200000
	eval - < 3 >: 200000
	eval - < 4 >: 200000
	eval - < 5 >: 200000
	eval - < 6 >: 200000
< 1 > value lens percentiles:
	85.0%: 2
	90.0%: 2
	95.0%: 3
	99.0%: 5
< 1 > oracle lens percentiles :
	85.0%: 8
	90.0%: 8
	95.0%: 8
	99.0%: 8
By value and by oracle length: (value = 5, oracle = 8)
	< 1 > count: 600 / 600 100.00%
By value only: 5
	< 1 > count: 600 / 600 100.00%
By oracle only: 8
	< 1 > count: 600 / 600 100.00%
