In [1]:
import collections
import math
from pathlib import Path
import itertools
import subprocess

import beartype
from beartype.typing import *
import matplotlib.pyplot as plt
import more_itertools
import numpy as np
import torch
import ujson as json
import jsonlines as jsonl
import rich
from tqdm import tqdm

import datagen


def cmd(command: list[str]) -> list[str]:
    return subprocess.check_output(command).decode("utf-8").strip().split("\n")

def only_one(it: Iterable):
    iterated = iter(it)
    good = next(iterated)
    for bad in iterated:
        raise ValueError("Expected only one item, got more than one.")
    return good

def check_len(it: Sequence, expected_len: int) -> Sequence:
    if not len(it) == expected_len:
        raise ValueError(f"Expected {expected_len} items, got {len(it)}.")
    return it

def count_lines(path: Path) -> int:
    return int(check_len(only_one(cmd(["wc", "-l", str(path)])).split(), 2)[0])




In [2]:
TARGET_DIR = Path("log_results/basic")
MIN_LENGTH = 10

print(list(TARGET_DIR.iterdir()))
active = []
for file in TARGET_DIR.iterdir():
    target = file / "predictions.jsonl"
    if target.exists():
        active.append(target)
lengths_active = [count_lines(file) for file in active]
print(collections.Counter(lengths_active))
selected = [file for file, length in more_itertools.zip_equal(active, lengths_active) if length >= MIN_LENGTH]
new_lengths = [count_lines(file) for file in selected]
print(collections.Counter(new_lengths))
print(len(selected))
min_length = min(new_lengths)
print(f"{min_length = }")

[PosixPath('log_results/basic/shared_config.json'), PosixPath('log_results/basic/basic_0'), PosixPath('log_results/basic/basic_1'), PosixPath('log_results/basic/basic_2'), PosixPath('log_results/basic/basic_3'), PosixPath('log_results/basic/basic_4'), PosixPath('log_results/basic/basic_5'), PosixPath('log_results/basic/basic_6'), PosixPath('log_results/basic/basic_7'), PosixPath('log_results/basic/basic_8'), PosixPath('log_results/basic/basic_9'), PosixPath('log_results/basic/basic_10'), PosixPath('log_results/basic/basic_11'), PosixPath('log_results/basic/basic_12'), PosixPath('log_results/basic/basic_13'), PosixPath('log_results/basic/basic_14'), PosixPath('log_results/basic/basic_15'), PosixPath('log_results/basic/basic_16'), PosixPath('log_results/basic/basic_17'), PosixPath('log_results/basic/basic_18'), PosixPath('log_results/basic/basic_19'), PosixPath('log_results/basic/basic_20'), PosixPath('log_results/basic/basic_21'), PosixPath('log_results/basic/basic_22'), PosixPath('log_

In [None]:
readers = [jsonl.open(file, "r") for file in selected]

for i, reader in enumerate(readers):
    epochs_seen = set()
    l0 = None
    l1 = None
    for line in reader:
        if line["epoch"] in epochs_seen:
            print(f"file {i} already seen epoch {line['epoch']}")
        epochs_seen.add(line["epoch"])
        if l0 is None:
            l0 = line
            print(l0)
        elif l1 is None:
            l1 = line
            print(l1)
            
    print(l0 == l1)


In [7]:
readers = [jsonl.open(file, "r") for file in selected]

file_iterator = [iter(reader) for reader in readers]
"""
epoch:
- {results: {input_eqn_str: {is_freeform_bool: {'per_batch_mode': tensor_list}}}}
- 

"""
counts = collections.defaultdict(lambda: collections.defaultdict(int))
counts_levels = collections.defaultdict(lambda: collections.defaultdict(int))

for epoch_num in tqdm(range(min_length)):
    eq = None
    print(f"real epoch: {epoch_num}")
    for file_idx, per_file in enumerate(file_iterator):
        per_file_per_epoch = next(per_file)
        if epoch_num == 0:
            continue

        assert per_file_per_epoch["epoch"] == epoch_num - 1, (per_file_per_epoch["epoch"], epoch_num - 1)
        keys = per_file_per_epoch["results"].keys()
        
        for key in keys:
            
            counts[per_file_per_epoch["epoch"]][key.strip()] += 1

            if per_file_per_epoch["epoch"] == 0:
                level = datagen.tree_depth_from_str(key)
                counts_levels[file_idx][level] += 1

    


  0%|          | 0/17 [00:00<?, ?it/s]

real epoch: 0


  6%|▌         | 1/17 [00:00<00:06,  2.30it/s]

real epoch: 1


 12%|█▏        | 2/17 [00:06<00:55,  3.70s/it]

real epoch: 2


 18%|█▊        | 3/17 [00:06<00:31,  2.22s/it]

real epoch: 3


 24%|██▎       | 4/17 [00:07<00:19,  1.52s/it]

real epoch: 4


 29%|██▉       | 5/17 [00:07<00:13,  1.15s/it]

real epoch: 5


 35%|███▌      | 6/17 [00:08<00:10,  1.02it/s]

real epoch: 6


 41%|████      | 7/17 [00:09<00:08,  1.16it/s]

real epoch: 7


 47%|████▋     | 8/17 [00:09<00:07,  1.25it/s]

real epoch: 8


 53%|█████▎    | 9/17 [00:10<00:06,  1.21it/s]

real epoch: 9


 59%|█████▉    | 10/17 [00:11<00:06,  1.03it/s]

real epoch: 10


 65%|██████▍   | 11/17 [00:13<00:07,  1.17s/it]

real epoch: 11


 71%|███████   | 12/17 [00:14<00:06,  1.22s/it]

real epoch: 12


 76%|███████▋  | 13/17 [00:16<00:04,  1.19s/it]

real epoch: 13


 82%|████████▏ | 14/17 [00:17<00:03,  1.22s/it]

real epoch: 14


 88%|████████▊ | 15/17 [00:18<00:02,  1.25s/it]

real epoch: 15


 94%|█████████▍| 16/17 [00:19<00:01,  1.26s/it]

real epoch: 16


100%|██████████| 17/17 [00:21<00:00,  1.26s/it]


TypeError: 'int' object is not subscriptable

In [10]:
for k, v in counts.items():
    print("epoch", k, collections.Counter(v.values()))

for k, v in counts_levels.items():
    print("file", k, dict(sorted(v.items(), key=lambda x: x[0])))


epoch 0 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 1 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 2 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 3 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 4 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 5 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 6 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 7 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 8 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 9 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 10 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 11 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 12 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 13 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 14 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
epoch 15 Counter({1: 50915, 2: 1264, 3: 19, 4: 1})
file 0 {1: 3, 2: 987, 3: 974, 4: 950, 5: 988, 6: 962}
file 1 {2: 916, 3: 934, 4: 1007, 5: 1020, 6: 987}
file 2 {1: 1, 2: 958, 3: 956, 4: 994, 5: 940, 6: 1015}
file 3 {1: 2, 2: 948, 3: 963, 4: 98