In [3]:
import os
import json
import re
import random
from collections import defaultdict

# Matches one angle-bracket token such as <e_12>, <r_3> or </a>.
_TOKEN_RE = re.compile(r"<[^>]+>")


def augment_dataset_with_new_atomic(  # function for fine tuning dataset
    old_data_path: str,
    new_data_path: str,
    *,
    num_new_atomic: int = 2000,
    keep_old_train: int = 8000,
    cap_test_inferred_new: int = 3000,
    cap_test_inferred_mix: int = 3000,
    rng_seed: int = 0,
    overwrite: bool = True,
    verify_atomic_inferred: bool = True,
    verify_print_topk: int = 10,
) -> None:
    """
    Create a new dataset by keeping a subset of the old train set and adding new atomic facts.

    Behavior is equivalent to the previous implementation with `half_ood=False`:
    - Output is written to `new_data_path`.
    - `test_inferred_new_mix` only composes with kept old atomic facts so that every old fact
      referenced by the mix split is guaranteed to be present in the new train set.

    Items are dicts with `input_text` / `target_text`. An item whose `input_text`
    has two tokens is treated as an atomic fact (head, relation); three tokens is
    treated as a two-hop inferred item (head, relation1, relation2).

    Args:
        old_data_path: Directory containing train.json, valid.json, test.json and vocab.json.
        new_data_path: Output directory (created if missing). Must differ from `old_data_path`.
        num_new_atomic: Number of new atomic facts to generate. They are first generated as
            chained pairs (a, r1, b), (b, r2, t); single facts then fill any remainder.
        keep_old_train: Upper bound on the number of old train items kept. Old atomic facts
            are sampled together with every old inferred item that depends on them.
        cap_test_inferred_new: Maximum number of new->new inferred probes kept (before dedup).
        cap_test_inferred_mix: Maximum number of new<->old inferred probes kept (before dedup).
        rng_seed: Seed of the private `random.Random` used for all sampling.
        overwrite: When True, existing output files are removed before writing. The files
            are opened in write mode either way, so existing files are replaced regardless.
        verify_atomic_inferred: Print selection statistics and raise if a selected atomic
            fact has a dependent inferred item that was not selected.
        verify_print_topk: Number of sampled atomic keys to print during verification.

    Raises:
        ValueError: If input and output paths are the same, or vocab.json contains no
            `<e_*>` / `<r_*>` tokens.
        FileNotFoundError: If any of the four input files is missing.
        RuntimeError: If no atomic fact can be parsed, the closure check fails, or fewer
            than `num_new_atomic` new facts could be generated.
    """
    rng = random.Random(rng_seed)

    out_data_path = new_data_path

    if os.path.abspath(old_data_path) == os.path.abspath(out_data_path):
        raise ValueError(
            "old_data_path must not be the same as the output path; otherwise it would overwrite the original dataset."
        )

    def load_json(p):
        with open(p, "r", encoding="utf-8") as f:
            return json.load(f)

    def dump_json(obj, p):
        with open(p, "w", encoding="utf-8") as f:
            json.dump(obj, f)

    def toks(s: str):
        return _TOKEN_RE.findall(s)

    def make_atomic(h: str, r: str, t: str):
        inp = f"{h}{r}"
        return {"input_text": inp, "target_text": f"{inp}{t}</a>"}

    def make_inferred(a: str, r1: str, r2: str, t: str):
        inp = f"{a}{r1}{r2}"
        return {"input_text": inp, "target_text": f"{inp}{t}</a>"}

    # ---- load old dataset ----
    train_path = os.path.join(old_data_path, "train.json")
    valid_path = os.path.join(old_data_path, "valid.json")
    test_path = os.path.join(old_data_path, "test.json")
    vocab_path = os.path.join(old_data_path, "vocab.json")

    if not (
        os.path.exists(train_path)
        and os.path.exists(valid_path)
        and os.path.exists(test_path)
        and os.path.exists(vocab_path)
    ):
        raise FileNotFoundError(
            "old_data_path must contain train.json, valid.json, test.json, vocab.json"
        )

    old_train_full = load_json(train_path)
    old_valid = load_json(valid_path)
    old_test = load_json(test_path)
    vocab_raw = load_json(vocab_path)

    # Vocab compatibility: vocab.json can be either list[token] or dict[token->id]
    if isinstance(vocab_raw, dict):
        vocab_tokens = list(vocab_raw.keys())
    else:
        vocab_tokens = list(vocab_raw)

    entities = [t for t in vocab_tokens if isinstance(t, str) and t.startswith("<e_")]
    relations = [t for t in vocab_tokens if isinstance(t, str) and t.startswith("<r_")]
    if not entities or not relations:
        raise ValueError(
            "Failed to parse entities/relations from vocab.json (expected <e_*/<r_*> tokens)."
        )

    # ---- split old train into atomic vs inferred ----
    # Items with any other token count are silently left out of both lists.
    old_atomic_items = []
    old_inferred_items = []
    for item in old_train_full:
        itoks = toks(item.get("input_text", ""))
        if len(itoks) == 2:
            old_atomic_items.append(item)
        elif len(itoks) == 3:
            old_inferred_items.append(item)

    # ---- build atomic map: (h,r) -> t ----
    atomic_key_to_item = {}
    atomic_map = {}
    # Every (head, relation) already present in the old train set; new facts must avoid them.
    used_pairs = set()
    used_rels_by_head = defaultdict(set)

    unresolved_atomic = 0
    for item in old_atomic_items:
        it = item.get("input_text", "")
        tt = item.get("target_text", "")
        itoks = toks(it)
        ttoks = toks(tt)
        if len(itoks) != 2 or len(ttoks) < 3:
            unresolved_atomic += 1
            continue
        h, r = itoks[0], itoks[1]
        # target_text repeats the input tokens, so the tail is the third token.
        t = ttoks[2]
        key = (h, r)

        # First occurrence wins when the same key appears more than once.
        if key not in atomic_key_to_item:
            atomic_key_to_item[key] = item
        if key not in atomic_map:
            atomic_map[key] = t

        used_pairs.add(key)
        used_rels_by_head[h].add(r)

    if not atomic_key_to_item:
        raise RuntimeError(
            "No valid atomic facts parsed from old train (expected len(tokens)==2)."
        )

    # ---- build dependency: atomic_key -> inferred indices that use it (at least one hop) ----
    atomic_key_to_inferred = defaultdict(set)
    unresolved_inferred = 0

    for idx, item in enumerate(old_inferred_items):
        it = item.get("input_text", "")
        tt = item.get("target_text", "")
        itoks = toks(it)
        ttoks = toks(tt)

        if len(itoks) != 3 or len(ttoks) < 4:
            unresolved_inferred += 1
            continue

        a, r1, r2 = itoks[0], itoks[1], itoks[2]
        key1 = (a, r1)
        b = atomic_map.get(key1, None)
        if b is None:
            # First hop is not in the atomic map: record the dependency on key1 only.
            atomic_key_to_inferred[key1].add(idx)
            unresolved_inferred += 1
            continue

        key2 = (b, r2)
        atomic_key_to_inferred[key1].add(idx)
        atomic_key_to_inferred[key2].add(idx)

    # ---- select old train by sampling atomic facts with closure of inferred ----
    all_atomic_keys = list(atomic_key_to_item.keys())
    rng.shuffle(all_atomic_keys)

    selected_atomic = set()
    selected_inferred = set()
    total_selected = 0

    def marginal_cost(key):
        # Number of train items that selecting `key` would add: the atomic fact itself
        # plus every dependent inferred item not selected yet.
        add_atomic = 0 if key in selected_atomic else 1
        add_inferred = atomic_key_to_inferred.get(key, set()) - selected_inferred
        return add_atomic + len(add_inferred), add_atomic, add_inferred

    # First pass: take keys in random order while they fit in the budget.
    for key in all_atomic_keys:
        cost, _, add_inferred = marginal_cost(key)
        if total_selected + cost <= keep_old_train:
            selected_atomic.add(key)
            selected_inferred |= add_inferred
            total_selected += cost
            if total_selected == keep_old_train:
                break

    # Second pass: try the skipped keys again, cheapest first, to fill the remaining budget.
    if total_selected < keep_old_train:
        remaining = [k for k in all_atomic_keys if k not in selected_atomic]
        rng.shuffle(remaining)
        remaining.sort(key=lambda k: marginal_cost(k)[0])

        for key in remaining:
            cost, _, add_inferred = marginal_cost(key)
            if cost == 0:
                continue
            if total_selected + cost <= keep_old_train:
                selected_atomic.add(key)
                selected_inferred |= add_inferred
                total_selected += cost
                if total_selected == keep_old_train:
                    break

    # ---- verify: closure for selected atomic keys ----
    if verify_atomic_inferred:
        missing_total = 0
        zero_dep = 0
        dep_sizes = []
        worst_missing = 0
        worst_key = None

        for key in selected_atomic:
            deps = atomic_key_to_inferred.get(key, set())
            dep_sizes.append(len(deps))
            if len(deps) == 0:
                zero_dep += 1
            missing = deps - selected_inferred
            if missing:
                missing_total += len(missing)
                if len(missing) > worst_missing:
                    worst_missing = len(missing)
                    worst_key = key

        print(
            "[verify] selected_atomic:",
            len(selected_atomic),
            "selected_inferred:",
            len(selected_inferred),
            "total_selected:",
            total_selected,
            f"(target={keep_old_train})",
        )
        if dep_sizes:
            dep_sizes_sorted = sorted(dep_sizes, reverse=True)
            print(
                "[verify] atomic->inferred dependency sizes:",
                f"min={min(dep_sizes)}, median={dep_sizes_sorted[len(dep_sizes)//2]}, max={max(dep_sizes)}",
            )
            print("[verify] selected atomic with dep_size==0:", zero_dep)

        if missing_total != 0:
            raise RuntimeError(
                f"[verify] Missing closure: selected_atomic has inferred dependencies that were not included. "
                f"missing_total={missing_total}, worst_key={worst_key}, worst_missing={worst_missing}"
            )
        else:
            print(
                "[verify] OK: For each selected atomic key, all dependent train inferred items (>=1 hop) are included."
            )

        if verify_print_topk > 0:
            sample_keys = list(selected_atomic)
            rng.shuffle(sample_keys)
            sample_keys = sample_keys[:verify_print_topk]
            for k in sample_keys:
                deps = atomic_key_to_inferred.get(k, set())
                print(f"[verify] sample atomic_key={k} dep_inferred_count={len(deps)}")

    # ---- assemble old_train_kept ----
    old_train_kept = []
    for key in selected_atomic:
        old_train_kept.append(atomic_key_to_item[key])
    for idx in sorted(selected_inferred):
        old_train_kept.append(old_inferred_items[idx])
    rng.shuffle(old_train_kept)

    # Only build structures from old atomic facts that are actually kept in train,
    # so that every old fact referenced by test_inferred_new_mix is guaranteed to be present in train.
    kept_old_atomic_out = defaultdict(list)
    kept_old_atomic_edges = []
    for (h, r) in selected_atomic:
        t = atomic_map.get((h, r), None)
        if t is None:
            continue
        kept_old_atomic_out[h].append((r, t))
        kept_old_atomic_edges.append((h, r, t))

    # ---- old test keys for dedup ----
    old_test_key = set()
    for item in old_test:
        old_test_key.add((item.get("input_text"), item.get("target_text"), item.get("type")))

    # ---- generate new atomic edges ----
    new_edges = []
    new_pairs = set()
    new_rels_by_head = defaultdict(set)

    def pick_unused_relation(head: str):
        # A relation is usable for `head` only if neither an old nor a new fact uses it.
        used = used_rels_by_head.get(head, set()) | new_rels_by_head.get(head, set())
        candidates = [r for r in relations if r not in used]
        if not candidates:
            return None
        return rng.choice(candidates)

    # Chained pairs (a, r1, b), (b, r2, t) so that new->new two-hop probes exist.
    target_pairs = num_new_atomic // 2
    attempts = 0
    while len(new_edges) < 2 * target_pairs and attempts < 500000:
        attempts += 1
        a = rng.choice(entities)
        b = rng.choice(entities)
        t = rng.choice(entities)

        r1 = pick_unused_relation(a)
        r2 = pick_unused_relation(b)
        if r1 is None or r2 is None:
            continue
        # When a == b the two picks are drawn from the same candidate list and can
        # coincide. Neither key is in new_pairs yet, so the checks below would let
        # two facts with the same (head, relation) through.
        if a == b and r1 == r2:
            continue
        if (a, r1) in used_pairs or (a, r1) in new_pairs:
            continue
        if (b, r2) in used_pairs or (b, r2) in new_pairs:
            continue

        new_edges.append((a, r1, b))
        new_pairs.add((a, r1))
        new_rels_by_head[a].add(r1)

        new_edges.append((b, r2, t))
        new_pairs.add((b, r2))
        new_rels_by_head[b].add(r2)

    # Single facts fill whatever is left (odd num_new_atomic or an exhausted pair loop).
    attempts = 0
    while len(new_edges) < num_new_atomic and attempts < 500000:
        attempts += 1
        h = rng.choice(entities)
        r = pick_unused_relation(h)
        if r is None:
            continue
        if (h, r) in used_pairs or (h, r) in new_pairs:
            continue
        t = rng.choice(entities)

        new_edges.append((h, r, t))
        new_pairs.add((h, r))
        new_rels_by_head[h].add(r)

    if len(new_edges) < num_new_atomic:
        raise RuntimeError(
            f"Only generated {len(new_edges)} new atomic facts (constraints may be too strict / not enough relations)."
        )

    # ---- build new train ----
    new_atomic_items = [make_atomic(h, r, t) for (h, r, t) in new_edges]
    new_train = list(old_train_kept) + new_atomic_items

    # ---- build inferred probes ----
    new_out = defaultdict(list)
    for (h, r, t) in new_edges:
        new_out[h].append((r, t))

    inferred_new = []
    for (a, r1, b) in new_edges:
        for (r2, t) in new_out.get(b, []):
            inferred_new.append(make_inferred(a, r1, r2, t))

    # ---- test_inferred_new_mix (restricted to kept old facts to guarantee coverage in train) ----
    inferred_mix = []
    # new -> old
    for (a, r1, b) in new_edges:
        for (r2, t) in kept_old_atomic_out.get(b, []):
            inferred_mix.append(make_inferred(a, r1, r2, t))
    # old -> new
    for (a, r1, b) in kept_old_atomic_edges:
        for (r2, t) in new_out.get(b, []):
            inferred_mix.append(make_inferred(a, r1, r2, t))

    rng.shuffle(inferred_new)
    rng.shuffle(inferred_mix)
    inferred_new = inferred_new[:cap_test_inferred_new]
    inferred_mix = inferred_mix[:cap_test_inferred_mix]

    if len(inferred_mix) < cap_test_inferred_mix:
        print(
            "[augment_dataset_with_new_atomic] warning: inferred_mix only",
            len(inferred_mix),
            "< cap_test_inferred_mix=",
            cap_test_inferred_mix,
            "(mix now restricted to kept old facts to guarantee coverage in train)",
        )

    # ---- assemble new test ----
    # Old test items are kept as-is; new probes are appended unless an identical
    # (input_text, target_text, type) triple is already present.
    new_test = list(old_test)

    for item in new_atomic_items:
        probe = dict(item)
        probe["type"] = "atomic_new"
        key = (probe["input_text"], probe["target_text"], probe["type"])
        if key not in old_test_key:
            new_test.append(probe)
            old_test_key.add(key)

    for item in inferred_new:
        probe = dict(item)
        probe["type"] = "test_inferred_new"
        key = (probe["input_text"], probe["target_text"], probe["type"])
        if key not in old_test_key:
            new_test.append(probe)
            old_test_key.add(key)

    for item in inferred_mix:
        probe = dict(item)
        probe["type"] = "test_inferred_new_mix"
        key = (probe["input_text"], probe["target_text"], probe["type"])
        if key not in old_test_key:
            new_test.append(probe)
            old_test_key.add(key)

    # ---- write new dataset (overwrite) ----
    os.makedirs(out_data_path, exist_ok=True)

    if overwrite:
        for fn in ["train.json", "valid.json", "test.json", "vocab.json"]:
            fp = os.path.join(out_data_path, fn)
            if os.path.exists(fp):
                os.remove(fp)

    dump_json(new_train, os.path.join(out_data_path, "train.json"))
    dump_json(old_valid, os.path.join(out_data_path, "valid.json"))
    dump_json(new_test, os.path.join(out_data_path, "test.json"))
    dump_json(vocab_raw, os.path.join(out_data_path, "vocab.json"))

    # ---- prints ----
    print("[augment_dataset_with_new_atomic] old_train_full:", len(old_train_full))
    print(
        "[augment_dataset_with_new_atomic] parsed atomic:",
        len(old_atomic_items),
        "parsed inferred:",
        len(old_inferred_items),
    )
    print("[augment_dataset_with_new_atomic] kept old:", len(old_train_kept), f"(target={keep_old_train})")
    print("[augment_dataset_with_new_atomic] added new atomic:", len(new_atomic_items), f"(target={num_new_atomic})")
    print("[augment_dataset_with_new_atomic] new train total:", len(new_train))
    # These counts are taken after the caps but before deduplication against the test set.
    print(
        "[augment_dataset_with_new_atomic] test additions (after caps):",
        "atomic_new=",
        len(new_atomic_items),
        "test_inferred_new=",
        len(inferred_new),
        "test_inferred_new_mix=",
        len(inferred_mix),
    )

    if unresolved_atomic or unresolved_inferred:
        print(
            "[augment_dataset_with_new_atomic] warnings:",
            f"unresolved_atomic={unresolved_atomic}, unresolved_inferred={unresolved_inferred}",
        )


# Load tokenizer
from transformers import GPT2Tokenizer
import json
import os
import numpy as np
def preprocess_data(data_path, save_path, *, max_seq_len=5, tok=None, token_to_id=None):
    """Encode a JSON split into fixed-length id arrays and save them under `save_path`.

    Each record's `input_text` and `target_text` is tokenized, mapped to integer
    ids and right-padded with 0 up to `max_seq_len`. Every record becomes its own
    puzzle and its own group, and all puzzle identifiers are 0.

    Files written to `save_path`: all__inputs.npy, all__labels.npy,
    all__puzzle_indices.npy, all__group_indices.npy, all__puzzle_identifiers.npy
    (all int32) and dataset.json with the metadata.
    NOTE(review): the file layout presumably targets the TinyRecursiveModels
    dataset loader — confirm against the training script.

    Args:
        data_path: Path to a JSON list of dicts with `input_text` and `target_text`.
        save_path: Output directory; created if it does not exist.
        max_seq_len: Fixed sequence length of the saved arrays.
        tok: Object with a `tokenize(text)` method. Defaults to the module-level
            `tokenizer`.
        token_to_id: Mapping from token to id. Defaults to the module-level
            `vocab_map`.

    Raises:
        ValueError: If a tokenized sequence is longer than `max_seq_len`.
        KeyError: If a token is missing from the id mapping.
    """
    # Fall back to the notebook-level objects so existing two-argument calls keep working.
    tok = tokenizer if tok is None else tok
    token_to_id = vocab_map if token_to_id is None else token_to_id

    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    os.makedirs(save_path, exist_ok=True)

    def encode(text, field, index):
        ids = [int(token_to_id[token]) for token in tok.tokenize(text)]
        if len(ids) > max_seq_len:
            # Without this check an over-long row would either break the array
            # construction or be saved with a seq_len that does not match dataset.json.
            raise ValueError(
                f"{field} of record {index} has {len(ids)} tokens, more than max_seq_len={max_seq_len}"
            )
        # 0 is the padding id (see "pad_id" in dataset.json below).
        return ids + [0] * (max_seq_len - len(ids))

    all_input_ids = []
    all_output_ids = []
    for index, data_piece in enumerate(data):
        all_input_ids.append(encode(data_piece['input_text'], 'input_text', index))
        all_output_ids.append(encode(data_piece['target_text'], 'target_text', index))

    num_examples = len(data)
    # The reshape keeps a 2-D (N, max_seq_len) layout even when the split is empty.
    all_input_ids = np.array(all_input_ids, dtype=np.int32).reshape(num_examples, max_seq_len)
    all_output_ids = np.array(all_output_ids, dtype=np.int32).reshape(num_examples, max_seq_len)
    # One example per puzzle and one puzzle per group: boundaries are 0, 1, ..., N.
    puzzle_index = np.arange(num_examples + 1, dtype=np.int32)
    group_index = np.arange(num_examples + 1, dtype=np.int32)
    # Saved as int32 like the other arrays (a plain list would be stored as the default int type).
    puzzle_identifier = np.zeros(num_examples, dtype=np.int32)

    dataset_json = {
        "pad_id": 0,
        "ignore_label_id": 0,
        "blank_identifier_id": 0,
        "vocab_size": len(token_to_id),
        "seq_len": max_seq_len,
        "num_puzzle_identifiers": 1,
        "total_groups": len(data),
        "mean_puzzle_examples": 1.0,
        "total_puzzles": len(data),
        "sets": ["all"],
    }
    np.save(os.path.join(save_path, 'all__group_indices.npy'), group_index)
    np.save(os.path.join(save_path, 'all__inputs.npy'), all_input_ids)
    np.save(os.path.join(save_path, 'all__labels.npy'), all_output_ids)
    np.save(os.path.join(save_path, 'all__puzzle_identifiers.npy'), puzzle_identifier)
    np.save(os.path.join(save_path, 'all__puzzle_indices.npy'), puzzle_index)
    with open(os.path.join(save_path, 'dataset.json'), 'w') as f:
        json.dump(dataset_json, f)


# Pre-training data preprocessing

In [9]:
import os
import json
from tqdm.notebook import tqdm


# Diagnostic cell: load the tokenizer saved with a Transformer checkpoint and print
# what it contains. It defines the notebook-level names `dataset_name`,
# `data_root_path_name`, `checkpoint_path`, `tokenizer`, `token2id` and `id2token`.
# Specify your data path (please make sure you have already trained a traditional Transformer).
# dataset_name = 'composition.2000.200.18.0_factaug_h1ratio0.0_h1k0_h2ratio0.0_h2k0' # Natural grokking
dataset_name = 'composition.2000.200.18.0_factaug_h1ratio0.5_h1k9_h2ratio0.5_h2k9' 
data_root_path_name = f'../Grokking_analysis/data/{dataset_name}'

checkpoint_path = f"../Grokking_analysis/output/{dataset_name}/checkpoint-2000"

print(f"\nüìÇ Checkpoint path: {checkpoint_path}")
print(f"‚úì Check if path exists: {os.path.exists(checkpoint_path)}")

# Method 1: Load using transformers directly
print(f"\n{'='*80}")
print(f"Method 1: Using transformers.GPT2Tokenizer.from_pretrained()")
print(f"{'='*80}")

# Best-effort: a failure is only printed, so `tokenizer` may be left undefined
# (or stale from an earlier run) when loading fails.
try:
    tokenizer = GPT2Tokenizer.from_pretrained(checkpoint_path)
    print(f"‚úÖ Tokenizer loaded successfully!")
    print(f"   Vocabulary size: {len(tokenizer)}")
    print(f"   Special tokens: {tokenizer.special_tokens_map}")
    
    # Test tokenization
    test_text = "<e_0><r_1><a>e_2</a>"
    tokens = tokenizer.tokenize(test_text)
    token_ids = tokenizer.encode(test_text)
    
    print(f"\nüß™ Test Tokenization:")
    print(f"   Input text: {test_text}")
    print(f"   Tokens: {tokens}")
    print(f"   Token IDs: {token_ids}")
    print(f"   Decoded text: {tokenizer.decode(token_ids)}")
    
except Exception as e:
    print(f"‚ùå Loading failed: {e}")

# Method 2: Manually load vocab.json (lower-level approach)
print(f"\n{'='*80}")
print(f"Method 2: Manually build mapping from vocab.json")
print(f"{'='*80}")

# This reads the checkpoint's vocab.json, not the dataset's vocab.json.
vocab_path = os.path.join(checkpoint_path, "vocab.json")
with open(vocab_path, "r") as f:
    vocab = json.load(f)

# Build the token <-> id mappings
token2id = vocab  # vocab.json is already in {token: id} format
id2token = {v: k for k, v in token2id.items()}

print(f"‚úÖ Vocabulary loaded successfully!")
print(f"   Vocabulary size: {len(token2id)}")
print(f"   First 10 tokens: {list(token2id.keys())[:10]}")
print(f"   Last 10 tokens: {list(token2id.keys())[-10:]}")

# Test manual mapping
# Unknown tokens fall back to the id of "<unk>", or 0 when "<unk>" is absent too.
test_tokens = ["<e_0>", "<r_1>", "<a>"]
test_ids = [token2id.get(t, token2id.get("<unk>", 0)) for t in test_tokens]
decoded_tokens = [id2token.get(i, "<unk>") for i in test_ids]

print(f"\nüß™ Test manual mapping:")
print(f"   Tokens: {test_tokens}")
print(f"   IDs: {test_ids}")
print(f"   Decoded: {decoded_tokens}")

# Method 3: View all files in the checkpoint
print(f"\n{'='*80}")
print(f"Method 3: View all files in the checkpoint")
print(f"{'='*80}")

# Only regular files are listed; sub-directories are skipped.
checkpoint_files = os.listdir(checkpoint_path)
print(f"üìã Files in the checkpoint:")
for f in sorted(checkpoint_files):
    file_path = os.path.join(checkpoint_path, f)
    if os.path.isfile(file_path):
        size = os.path.getsize(file_path)
        print(f"   - {f:30s} ({size:>12,} bytes)")

# Special token information
print(f"\n{'='*80}")
print(f"üéØ Special Token Analysis")
print(f"{'='*80}")

# Only the tokens that are present in the checkpoint's vocab.json are printed.
special_tokens = ["<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]
print(f"Dataset-defined special tokens:")
for token in special_tokens:
    if token in token2id:
        print(f"   {token:10s} ‚Üí ID: {token2id[token]}")

# Entity and relation token statistics
entity_tokens = [k for k in token2id.keys() if k.startswith("<e_")]
relation_tokens = [k for k in token2id.keys() if k.startswith("<r_")]
print(f"\nüìä Token type statistics:")
print(f"   Entity tokens (<e_*): {len(entity_tokens)}")
print(f"   Relation tokens (<r_*): {len(relation_tokens)}")
print(f"   Special tokens: {len(special_tokens)}")
print(f"   Total: {len(token2id)}")


# Placeholder holding only the padding entry; it is replaced by the mapping built below.
vocab_map = {'<PAD>':0}
with open(f"{data_root_path_name}/vocab.json", "r") as f:
    vocab = json.load(f)

# Token ids start at 1 in vocab order; preprocess_data pads with 0.
vocab_map = dict(zip(vocab, range(1, len(vocab) + 1)))
reverse_vocab_map = dict(zip(vocab_map.values(), vocab_map.keys()))

# Persist both mappings next to the dataset, but never replace files that already exist.
for _map_file, _mapping in (
    ("reverse_vocab_map.json", reverse_vocab_map),
    ("vocab_map.json", vocab_map),
):
    _map_path = f"{data_root_path_name}/{_map_file}"
    if os.path.exists(_map_path):
        continue
    with open(_map_path, "w") as f:
        json.dump(_mapping, f)

# Encode the three splits, in the order test, train, valid.
for _split in ("test", "train", "valid"):
    preprocess_data(f"{data_root_path_name}/{_split}.json",
                    f'data/{dataset_name}/{_split}')

# A single identifier name is written next to the encoded splits.
identifiers_file = ["<blank>"]
with open(f'data/{dataset_name}/identifiers.json', 'w') as f:
    json.dump(identifiers_file, f)


📂 Checkpoint path: ../Grokking_analysis/output/composition.2000.200.18.0_factaug_h1ratio0.5_h1k9_h2ratio0.5_h2k9/checkpoint-2000
✓ Check if path exists: True

Method 1: Using transformers.GPT2Tokenizer.from_pretrained()
✅ Tokenizer loaded successfully!
   Vocabulary size: 52463
   Special tokens: {'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}

🧪 Test Tokenization:
   Input text: <e_0><r_1><a>e_2</a>
   Tokens: ['<e_0>', '<r_1>', '<a>', 'e', '_', '2', '</a>']
   Token IDs: [50257, 52258, 52459, 68, 62, 17, 52460]
   Decoded text: <e_0> <r_1> <a> e_2 </a>

Method 2: Manually build mapping from vocab.json
✅ Vocabulary loaded successfully!
   Vocabulary size: 50257
   First 10 tokens: ['!', '!!', '!!!', '!!!!', '!!!!!', '!!!!!!!!', '!!"', '!"', '!",', '!".']
   Last 10 tokens: ['ļéĨĴ', 'Ľ', 'ľ', 'Ŀ', 'ŀ', 'Ł', 'ł', 'Ń', 'Ń·', 'ŃĶ']

🧪 Test manual mapping:
   Tokens: ['<e_0>', '<r_1>', '

### Use the following command to pre-train the model on the preprocessed dataset

- ```run_name="pretrain_grok_composition" && CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain_grok_evaluate_ver_0_1.py   arch=trm   data_paths="[data/composition.2000.200.18.0_factaug_h1ratio0.5_h1k9_h2ratio0.5_h2k9]"   evaluators="[]"   epochs=1100   eval_interval=5   lr=4e-5   puzzle_emb_lr=1e-4   weight_decay=1.0   puzzle_emb_weight_decay=1.0   arch.mlp_t=True   arch.pos_encodings=None   arch.L_layers=2   arch.H_cycles=2   arch.L_cycles=6 arch.halt_max_steps=1 arch.hidden_size=1536 +run_name=${run_name}  ema=True   global_batch_size=512  +max_inference_steps=1 checkpoint_every_eval=True   +format="maintain_prefix" +causal=False +post_fix="anything_here_you_like"```

# Fine-tuning data preprocessing (please make sure you have already trained a traditional Transformer on the original dataset)

In [11]:
# Add new, previously unseen facts to the dataset (no-augmentation version).
# dataset_name = 'composition.2000.200.18.0_factaug_h1ratio0.0_h1k0_h2ratio0.0_h2k0' # Natural grokking
dataset_name = 'composition.2000.200.18.0_factaug_h1ratio0.5_h1k9_h2ratio0.5_h2k9'
old_data_path = f'../Grokking_analysis/data/{dataset_name}'  # original dataset path
new_data_path = f'{old_data_path}_finetuning'

augment_dataset_with_new_atomic(old_data_path, new_data_path)

# Read the generated test split back and count the probes of each new type.
with open(f'{new_data_path}/test.json', 'r') as f:
    asd = json.load(f)

probe_type_counts = defaultdict(int)
for i in asd:
    probe_type_counts[i['type']] += 1

new_atomic_count = probe_type_counts.get('atomic_new', 0)
new_infer_count = probe_type_counts.get('test_inferred_new', 0)
mix_infer_count = probe_type_counts.get('test_inferred_new_mix', 0)
print(new_atomic_count, new_infer_count, mix_infer_count)

import os
import json
from tqdm.notebook import tqdm
# Diagnostic cell: load the tokenizer saved with the Transformer checkpoint and print
# what it contains. It defines the notebook-level names `data_root_path_name`,
# `checkpoint_path`, `tokenizer`, `token2id` and `id2token`.
print("="*80)
print("üîß Load Tokenizer from checkpoints")
print("="*80)


data_root_path_name = new_data_path.split('/')[-1] # folder name (last path component) of the fine-tuning dataset
old_data_folder_name = old_data_path.split('/')[-1] # folder name of the original dataset; not used again in this part of the cell

checkpoint_path = f"../Grokking_analysis/output/{dataset_name}/checkpoint-2000" 


tokenizer = GPT2Tokenizer.from_pretrained(checkpoint_path)


# This reads the checkpoint's vocab.json, not the dataset's vocab.json.
vocab_path = os.path.join(checkpoint_path, "vocab.json")
with open(vocab_path, "r") as f:
    vocab = json.load(f)

# construct token <-> id mapping
token2id = vocab  # vocab.json already in {token: id} format
id2token = {v: k for k, v in token2id.items()}



# test mapping
# Unknown tokens fall back to the id of "<unk>", or 0 when "<unk>" is absent too.
test_tokens = ["<e_0>", "<r_1>", "<a>"]
test_ids = [token2id.get(t, token2id.get("<unk>", 0)) for t in test_tokens]
decoded_tokens = [id2token.get(i, "<unk>") for i in test_ids]

print(f"\nüß™ test mapping:")
print(f"   Tokens: {test_tokens}")
print(f"   IDs: {test_ids}")
print(f"   decode: {decoded_tokens}")


# List the regular files in the checkpoint directory with their sizes.
checkpoint_files = os.listdir(checkpoint_path)
print(f"üìã Checkpoint files:")
for f in sorted(checkpoint_files):
    file_path = os.path.join(checkpoint_path, f)
    if os.path.isfile(file_path):
        size = os.path.getsize(file_path)
        print(f"   - {f:30s} ({size:>12,} bytes)")

# special token information
print(f"\n{'='*80}")
print(f"üéØ special Token analysis")
print(f"{'='*80}")

# Only the tokens that are present in the checkpoint's vocab.json are printed.
special_tokens = ["<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]
print(f"special token:")
for token in special_tokens:
    if token in token2id:
        print(f"   {token:10s} ‚Üí ID: {token2id[token]}")

# entity and relation token statistics
entity_tokens = [k for k in token2id.keys() if k.startswith("<e_")]
relation_tokens = [k for k in token2id.keys() if k.startswith("<r_")]
print(f"\nüìä Token type statics:")
print(f"   entity token (<e_*): {len(entity_tokens)}")
print(f"   relation token (<r_*): {len(relation_tokens)}")
print(f"   special token: {len(special_tokens)}")
print(f"   total token count: {len(token2id)}")


# Build the token -> id mapping from the fine-tuning dataset's vocab.json.
# Ids start at 1 in vocab order; preprocess_data pads with 0.
# NOTE(review): '<PAD>' itself is not an entry of vocab_map, so the "vocab_size"
# written by preprocess_data equals len(vocab) while the largest id is also
# len(vocab) — confirm the model's embedding size accounts for that.
with open(f"../Grokking_analysis/data/{data_root_path_name}/vocab.json", "r") as f:
    vocab = json.load(f)
vocab_map = {token: idx+1 for idx, token in enumerate(vocab)}
reverse_vocab_map = {idx: token for token, idx in vocab_map.items()}

# Persist both mappings next to the dataset, but never replace files that already exist.
if not os.path.exists(f"../Grokking_analysis/data/{data_root_path_name}/reverse_vocab_map.json"):
    with open(f"../Grokking_analysis/data/{data_root_path_name}/reverse_vocab_map.json", "w") as f:
        json.dump(reverse_vocab_map, f)

if not os.path.exists(f"../Grokking_analysis/data/{data_root_path_name}/vocab_map.json"):
    with open(f"../Grokking_analysis/data/{data_root_path_name}/vocab_map.json", "w") as f:
        json.dump(vocab_map, f)


# Encode the three splits into data/<dataset folder>/<split>.
preprocess_data(f"../Grokking_analysis/data/{data_root_path_name}/test.json",
                f'data/{data_root_path_name}/test')
preprocess_data(f"../Grokking_analysis/data/{data_root_path_name}/train.json",
                f'data/{data_root_path_name}/train')
preprocess_data(f"../Grokking_analysis/data/{data_root_path_name}/valid.json",
                f'data/{data_root_path_name}/valid')

# Write identifiers.json into the same dataset directory as the encoded splits.
# The previous user-specific absolute path only existed on one machine and did
# not match the directory the splits above are written to.
identifiers_file = ["<blank>"]
with open(f'data/{data_root_path_name}/identifiers.json', 'w') as f:
    json.dump(identifiers_file, f)

[verify] selected_atomic: 220 selected_inferred: 7780 total_selected: 8000 (target=8000)
[verify] atomic->inferred dependency sizes: min=0, median=37, max=48
[verify] selected atomic with dep_size==0: 4
[verify] OK: For each selected atomic key, all dependent train inferred items (>=1 hop) are included.
[verify] sample atomic_key=('<e_286>', '<r_147>') dep_inferred_count=41
[verify] sample atomic_key=('<e_632>', '<r_21>') dep_inferred_count=30
[verify] sample atomic_key=('<e_1853>', '<r_71>') dep_inferred_count=45
[verify] sample atomic_key=('<e_209>', '<r_86>') dep_inferred_count=28
[verify] sample atomic_key=('<e_491>', '<r_178>') dep_inferred_count=0
[verify] sample atomic_key=('<e_1097>', '<r_115>') dep_inferred_count=41
[verify] sample atomic_key=('<e_233>', '<r_96>') dep_inferred_count=32
[verify] sample atomic_key=('<e_905>', '<r_179>') dep_inferred_count=39
[verify] sample atomic_key=('<e_1674>', '<r_8>') dep_inferred_count=37
[verify] sample atomic_key=('<e_1768>', '<r_177>') 

### Use the following command to fine-tune a model
- ```run_name="TRM_finetune" && LOAD_CKPT={Your TRM checkpoint path here} && CUDA_VISIBLE_DEVICES=0 torchrun --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain_grok_evaluate_ver_0_1.py arch=trm data_paths="[{Your finetune datapath here}]" evaluators="[]" +load_checkpoint="${LOAD_CKPT}" epochs=20000 eval_interval=50 lr=2e-5 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 arch.mlp_t=True arch.pos_encodings=None arch.L_layers=2 arch.H_cycles=2 arch.L_cycles=6 arch.halt_max_steps=1 arch.hidden_size=1536 ema=True global_batch_size=512 +max_inference_steps=1 checkpoint_every_eval=False +causal=False +run_name="${run_name}" +post_fix="2000.200.18.0_no_aug_finetuning"``` (Keep the architecture parameters the same as those of your pre-training checkpoint.)