In [1]:
print("Doing imports")
import collections
import math
from pathlib import Path
import itertools
import subprocess
import time
import concurrent.futures
import multiprocessing

import beartype
from beartype.typing import *
import matplotlib.pyplot as plt
import more_itertools
import numpy as np
# import ujson as json
import orjson as json
import jsonlines as jsonl
import rich
from tqdm import tqdm


print("Done with other imports")

import our_tokenizer
# import data_generation_arithmetic

print("Done with torch")


def cmd(command: list[str]) -> list[str]:
    return subprocess.check_output(command).decode("utf-8").strip().split("\n")

def only_one(it: Iterable):
    iterated = iter(it)
    good = next(iterated)
    for bad in iterated:
        raise ValueError("Expected only one item, got more than one.")
    return good

def check_len(it: Sequence, expected_len: int) -> Sequence:
    if not len(it) == expected_len:
        raise ValueError(f"Expected {expected_len} items, got {len(it)}.")
    return it

def count_lines(path: Path) -> int:
    return int(check_len(only_one(cmd(["wc", "-l", str(path)])).split(), 2)[0])

def count_lines_list(paths: Path) -> int:
    with concurrent.futures.ThreadPoolExecutor() as tp:
        futures = {file: tp.submit(lambda: count_lines(file)) for file in paths}
        return {file: future.result() for file, future in futures.items()}
    

Doing imports
Done with other imports
Done with torch


In [2]:
#########################################################################################################
# List the output files of an experiment
#########################################################################################################
TARGET_DIR = Path("log_results/oracle")

directories = list(TARGET_DIR.iterdir())
active = []
for file in tqdm(directories):
    target = file / "predictions.jsonl"
    if target.exists():
        active.append(target)
rich.print("Directories:")
rich.print(active)


#########################################################################################################
# Select the files that we will use for the analysis
#########################################################################################################
MIN_LENGTH = 20

rich.print("Doing count line pre-filtration")
start = time.perf_counter()
lengths_active = count_lines_list(active)
rich.print(f"Took {time.perf_counter() - start} seconds")

selected = [file for file, length in tqdm(lengths_active.items(), desc="filtering") if length >= MIN_LENGTH]
new_lengths = [lengths_active[file] for file in tqdm(selected, desc="count lines post-filtration")]
min_length = min(new_lengths)

rich.print("Length counter pre-filtration", collections.Counter(lengths_active))
rich.print("Number of files post filtration:", len(selected))
rich.print("Length counter post filtration:", collections.Counter(new_lengths))
rich.print("Shortest file post filtration:", min_length)

100%|██████████| 33/33 [00:00<00:00, 1796.25it/s]


filtering: 100%|██████████| 17/17 [00:00<00:00, 68298.05it/s]
count lines post-filtration: 100%|██████████| 17/17 [00:00<00:00, 71805.81it/s]


In [3]:
start = time.perf_counter()
files_to_epochs = {}

def fn(file: Path):
    with jsonl.open(file) as reader:
        output = []
        for line in itertools.islice(reader, min_length):
            output.append(line)
    return file, output

with multiprocessing.Pool(8) as pp:
    for i, (file, output) in enumerate(pp.map(fn, selected)):
        files_to_epochs[file] = output

print(time.perf_counter() - start)

2778.983425117098


In [6]:
def assert_fn(condition, message):
    assert condition, message


# Check if lines repeat themselves 
def check_all_same_keys(path, obj, keys_seen: set[str]):
    epochs_seen = set()

    for i, epoch_content in enumerate(tqdm(obj)):

        if i == 0:
            continue 

        if epoch_content["epoch"] in epochs_seen:
            print(f"file {path} already seen epoch {epoch_content['epoch']}")
        
        epochs_seen.add(epoch_content["epoch"])
        
        keys = epoch_content["results"].keys()
        
        [assert_fn(isinstance(x, str), type(x)) for x in keys]
        
        if keys_seen:
            for key in keys:
                assert key.strip() in keys_seen, key
        else:
            keys_seen.update(key.strip() for key in keys)

keys = set()
for file, obj in tqdm(files_to_epochs.items()):
    check_all_same_keys(file, obj, keys)
    
print("Looks all good.")


  0%|          | 0/17 [00:00<?, ?it/s]

file log_results/oracle/oracle_0/predictions.jsonl already seen epoch 19


100%|██████████| 55/55 [00:01<00:00, 37.48it/s]
  6%|▌         | 1/17 [00:01<00:23,  1.47s/it]

file log_results/oracle/oracle_1/predictions.jsonl already seen epoch 19


100%|██████████| 55/55 [00:01<00:00, 39.02it/s]
100%|██████████| 55/55 [00:01<00:00, 35.78it/s]
100%|██████████| 55/55 [00:01<00:00, 42.37it/s]
100%|██████████| 55/55 [00:01<00:00, 40.79it/s]
100%|██████████| 55/55 [00:01<00:00, 37.00it/s]
100%|██████████| 55/55 [00:01<00:00, 38.81it/s]
100%|██████████| 55/55 [00:01<00:00, 40.51it/s]
100%|██████████| 55/55 [00:01<00:00, 39.74it/s]
100%|██████████| 55/55 [00:01<00:00, 38.32it/s]
100%|██████████| 55/55 [00:01<00:00, 39.49it/s]
100%|██████████| 55/55 [00:01<00:00, 40.40it/s]
100%|██████████| 55/55 [00:01<00:00, 40.18it/s]
100%|██████████| 55/55 [00:01<00:00, 37.34it/s]
100%|██████████| 55/55 [00:01<00:00, 36.25it/s]
100%|██████████| 55/55 [00:01<00:00, 36.73it/s]
 94%|█████████▍| 16/17 [00:22<00:01,  1.46s/it]

file log_results/oracle/oracle_16/predictions.jsonl already seen epoch 8


100%|██████████| 55/55 [00:01<00:00, 36.32it/s]
100%|██████████| 17/17 [00:24<00:00,  1.43s/it]

Looks all good.





In [7]:
#########################################################################################################
# Check how well the models agree.
#########################################################################################################

file_iterator = {file: iter(epoch_list) for file, epoch_list in files_to_epochs.items()}
"""
epoch:
- {results: {input_eqn_str: {is_freeform_bool: {'per_batch_mode': tensor_list}}}}
- 
"""

tokenizer = our_tokenizer.ArithmeticTokenizer()
equal_idx = tokenizer.token_to_idx["="]

def find_last_equal(line: list[int]):
    for i in range(len(line) - 1, -1, -1):
        if line[i] == equal_idx:
            return i
    return None

per_epoch_agreement = {}
for epoch_num in tqdm(range(min_length)):
    eqs_to_values = collections.defaultdict(lambda: collections.defaultdict(list))
    for file_idx, (path, per_file) in enumerate(file_iterator.items()):
        per_file_per_epoch = next(per_file)
        if epoch_num == 0:
            continue

        assert per_file_per_epoch["epoch"] == epoch_num - 1, (
            per_file_per_epoch["epoch"], epoch_num - 1
        )
        
        for key, value in per_file_per_epoch["results"].items():
            prediction = value["True"]["per_batch"]
            maybe_last_equal = find_last_equal(prediction)
            if maybe_last_equal is not None:
                numerical_prediction = prediction[maybe_last_equal + 1:]
                numerical_prediction = tokenizer.decode(numerical_prediction, ignore_special_symbols=True)
                numerical_prediction = numerical_prediction.replace(" ", "")
                numerical_prediction = numerical_prediction.replace(")", "")
                eqs_to_values[key][path] = numerical_prediction.strip()
            else:
                eqs_to_values[key][path] = None                
    
    if epoch_num == 0:
        continue

    agreement = dict()

    for k, v in eqs_to_values.items():
        counter = collections.Counter(v.values())
        if v:
            agreement[k] = counter.most_common(1)[0][1] / len(v)

    mean = np.mean(np.fromiter(agreement.values(), dtype=np.float64))
    rich.print(f"[bold]Epoch {epoch_num}:[/] {mean:0.2%}")
    per_epoch_agreement[epoch_num] = agreement
    


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


  4%|▎         | 2/55 [00:22<10:02, 11.37s/it]

  5%|▌         | 3/55 [00:45<14:06, 16.27s/it]

In [None]:
# for k, v in counts.items():
#     print("epoch", k, collections.Counter(v.values()))

for k, v in counts_levels.items():
    print("file", k, dict(sorted(v.items(), key=lambda x: x[0])))


file 0 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 1 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 2 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 3 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 4 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 5 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 6 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 7 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 8 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 9 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 10 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 11 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 12 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 13 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 14 {1: 300, 2: 10000, 3: 10000, 4: 10000, 5: 10000, 6: 10000}
file 