In [None]:
# load_and_sample.py
# Robust checkpoint loader for TinyGPT checkpoints produced by tiny_gpt_toy_rl.py

import sys, types, math, random
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---- 1) Define TinyGPT exactly as in training (forward returns (logits, hidden)) ----
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position attends only to itself and earlier positions.

    The submodule/buffer attribute names (key, query, value, proj, attn_drop,
    resid_drop, mask) determine the state_dict keys, so they are kept as-is.
    """

    def __init__(self, n_embd, n_head, block_size, attn_pdrop=0.0, resid_pdrop=0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # Lower-triangular ones: entry (i, j) is 1 iff position i may look at j.
        lower_tri = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", lower_tri.view(1, 1, block_size, block_size))

    def _split_heads(self, t, batch, steps, width):
        """Reshape (batch, steps, width) into (batch, n_head, steps, width // n_head)."""
        return t.view(batch, steps, self.n_head, width // self.n_head).transpose(1, 2)

    def forward(self, x):
        """Apply masked attention to ``x`` of shape (B, T, C); returns the same shape."""
        batch, steps, width = x.size()
        keys = self._split_heads(self.key(x), batch, steps, width)
        queries = self._split_heads(self.query(x), batch, steps, width)
        values = self._split_heads(self.value(x), batch, steps, width)

        # Scaled dot-product scores; future positions get -inf so softmax zeroes them.
        scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(keys.size(-1))
        hidden_future = self.mask[:, :, :steps, :steps] == 0
        weights = F.softmax(scores.masked_fill(hidden_future, float('-inf')), dim=-1)
        weights = self.attn_drop(weights)

        # Merge the heads back into a single (B, T, C) tensor before projecting.
        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, steps, width)
        return self.resid_drop(self.proj(merged))

class Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, n_embd, n_head, block_size, resid_pdrop=0.0, attn_pdrop=0.0):
        super().__init__()
        mlp_width = 4 * n_embd
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(
            n_embd,
            n_head,
            block_size,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        # Expand to 4x width, apply GELU, project back, then dropout.
        feed_forward = [
            nn.Linear(n_embd, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, n_embd),
            nn.Dropout(resid_pdrop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

class TinyGPT(nn.Module):
    """Small GPT-style decoder: token + position embeddings, a stack of Blocks, final LayerNorm and LM head.

    ``forward`` returns a pair ``(logits, hidden)`` where ``hidden`` is the
    output of the final LayerNorm that the head is applied to.
    """

    def __init__(self, vocab_size, block_size, n_layer=4, n_head=4, n_embd=256,
                 emb_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0):
        super().__init__()
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.drop = nn.Dropout(emb_pdrop)

        layers = []
        for _ in range(n_layer):
            layers.append(
                Block(n_embd, n_head, block_size, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop)
            )
        self.blocks = nn.Sequential(*layers)

        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule: N(0, 0.02) weights, zero biases, identity LayerNorm."""
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
            return
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx):
        """Map token ids ``idx`` of shape (B, T), T <= block_size, to ``(logits, hidden)``."""
        _, seq_len = idx.size()
        assert seq_len <= self.block_size
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        token_vecs = self.tok_emb(idx)
        # Leading axis of size 1 lets the position vectors broadcast over the batch.
        position_vecs = self.pos_emb(positions)[None, :, :]
        hidden = self.drop(token_vecs + position_vecs)
        hidden = self.ln_f(self.blocks(hidden))
        return self.head(hidden), hidden

# ---- 2) Stub a Config so torch.load can unpickle checkpoints that saved the dataclass ----
# We don't need fields to reconstruct the object; we only need a class with the right name in __main__.
class Config:
    """Empty stand-in so pickle can resolve ``__main__.Config``; attributes are set on it at load time."""


# Register the stub on the running __main__ module (creating that module if it
# is somehow absent), which is where pickle looks the class up in notebooks/REPLs.
main_mod = sys.modules.get("__main__")
if main_mod is None:
    main_mod = sys.modules["__main__"] = types.ModuleType("__main__")
main_mod.Config = Config

# ---- 3) Loader that also handles optional q_head ----
def load_checkpoint(checkpoint_path, device="cpu"):
    """Rebuild a TinyGPT from a training checkpoint and return it in eval mode.

    The checkpoint is read as a dict with the keys "config", "vocab_size",
    "stoi", "itos" and "model_state_dict".

    Args:
        checkpoint_path: Path (or file-like object) handed to ``torch.load``.
        device: Device the tensors are mapped to and the model is moved to.

    Returns:
        Tuple ``(model, stoi, itos, cfg)``.

    Raises:
        RuntimeError: If the config sets ``use_q_head`` but the state dict
            does not contain the q_head weights.
    """
    # SECURITY: the checkpoint stores a pickled config object, so it has to be
    # unpickled in full (weights_only=False). Unpickling can execute arbitrary
    # code; only load checkpoints from a trusted source. Being explicit also
    # keeps this working on PyTorch versions where weights_only defaults to True.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = ckpt["config"]      # unpickled into the stub Config class
    vocab_size = ckpt["vocab_size"]
    stoi = ckpt["stoi"]
    itos = ckpt["itos"]

    # Build the base model; dropout is disabled because this loader is for inference.
    model = TinyGPT(
        vocab_size=vocab_size,
        block_size=cfg.block_size,
        n_layer=cfg.n_layer,
        n_head=cfg.n_head,
        n_embd=cfg.n_embd,
        emb_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0
    ).to(device)

    # If training used a separate Q head, attach it before loading weights so
    # its parameters have somewhere to go.
    use_q_head = getattr(cfg, "use_q_head", False)
    if use_q_head:
        model.q_head = nn.Linear(cfg.n_embd, vocab_size, bias=True).to(device)

    # strict=False tolerates key mismatches; report every mismatch so a
    # partially initialised model never goes unnoticed.
    missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    if use_q_head and ("q_head.weight" in missing or "q_head.bias" in missing):
        raise RuntimeError("Checkpoint expects q_head, but it was not loaded correctly.")
    if missing:
        print(f"[warn] Missing keys in state_dict (left at their initial values): {missing}")
    if unexpected:
        print(f"[warn] Unexpected keys in state_dict: {unexpected}")
    model.eval()
    return model, stoi, itos, cfg

# ---- 4) Sampling (fix: unpack (logits, _) correctly) ----
def generate_continuation(model, prompt, stoi, itos, max_new_tokens=150, device="cpu", temperature=1.0):
    """Extend ``prompt`` one token at a time using ``model`` and return prompt + continuation.

    Generation stops after ``max_new_tokens`` tokens, or earlier as soon as a
    ')' is emitted that makes the number of '(' and ')' in the full output
    (prompt included) equal.

    Args:
        model: Callable returning ``(logits, _)`` for a (1, T) LongTensor and
            exposing ``block_size`` and ``eval()``.
        prompt: Non-empty text to condition on. Characters missing from
            ``stoi`` are encoded as token id 0.
        stoi: Mapping from character to token id.
        itos: Mapping (or sequence) from token id to character.
        max_new_tokens: Upper bound on the number of generated tokens.
        device: Device on which the input tensor is created.
        temperature: Softmax temperature. 0 selects the arg-max token
            (greedy decoding) instead of dividing by zero.

    Returns:
        The prompt followed by the generated characters, as one string.

    Raises:
        ValueError: If ``prompt`` is empty or ``temperature`` is negative.
    """
    if not prompt:
        # With no tokens there is no "last position" to read logits from.
        raise ValueError("prompt must contain at least one character")
    temperature = float(temperature)
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    greedy = temperature == 0.0

    model.eval()
    # encode prompt (unknown chars -> 0)
    x = torch.tensor([[stoi.get(c, 0) for c in prompt]], dtype=torch.long, device=device)
    out = list(prompt)
    # Running paren counts replace a full recount of `out` on every step.
    opens = out.count('(')
    closes = out.count(')')
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # The model only accepts up to block_size tokens of context.
            x_cond = x if x.size(1) <= model.block_size else x[:, -model.block_size:]
            logits, _ = model(x_cond)
            logits = logits[:, -1, :]
            if greedy:
                ix = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                ix = torch.multinomial(probs, num_samples=1)
            ch = itos[ix.item()]
            out.append(ch)
            # Only the cropped window is ever fed back, so keep just that.
            x = torch.cat([x_cond, ix], dim=1)
            if ch == '(':
                opens += 1
            elif ch == ')':
                closes += 1
                # early stop: this ')' closes every paren in the output so far
                if opens == closes:
                    break
    return "".join(out)

# ---- 5) Validity helpers (check_validity and is_dyck_parens implement the same check) ----
def check_validity(sequence: str) -> bool:
    """Return True iff the parentheses in ``sequence`` are balanced; other characters are ignored."""
    step = {'(': 1, ')': -1}
    balance = 0
    for symbol in sequence:
        balance += step.get(symbol, 0)
        # A negative balance means a ')' appeared with nothing left to close.
        if balance < 0:
            return False
    return balance == 0

def is_dyck_parens(s: str) -> bool:
    """Return True iff the '(' / ')' characters of ``s`` form a well-nested sequence."""
    pending = 0  # number of '(' still waiting for their ')'
    for ch in s:
        if ch == ')':
            if pending == 0:
                return False
            pending -= 1
        elif ch == '(':
            pending += 1
    return pending == 0

# # ---- 6) Demo ----
# if __name__ == "__main__":



In [4]:
# Load the iteration-3600 checkpoint, on the GPU when one is available and on the CPU otherwise.
checkpoint_path = "/nfs/shuozhe/clean_pretrain/tiny_llm_checkpoints/checkpoint_iter_3600.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, stoi, itos, cfg = load_checkpoint(checkpoint_path, device=device)

# Sample up to 150 new tokens after the prompt, then report whether the parentheses
# of the whole output (prompt included) are balanced. Any prompt character that is
# missing from stoi is encoded as token id 0 by generate_continuation.
user_input = "((((ab0("
out = generate_continuation(model, user_input, stoi, itos, max_new_tokens=150, device=device, temperature=1.0)
print(f"Input:    {user_input}")
print(f"Output:   {out}")
print(f"Validity: {check_validity(out)}")

  checkpoint = torch.load(checkpoint_path, map_location=device)


AttributeError: Can't get attribute 'Config' on <module '__main__'>

In [85]:
#!/usr/bin/env python3
"""
Validate Dyck-1 (parentheses) in a string while ignoring all other characters.

Definition: Keep only '(' and ')' from the input. The string is valid iff
the resulting sequence is a well-formed Dyck word: the running depth never
goes negative and ends at zero.
"""

from typing import Tuple, Optional

def is_dyck_parens(s: str) -> bool:
    """
    Decide whether the parentheses of `s`, read left to right with every
    other character skipped, form a valid Dyck word.

    Examples:
        is_dyck_parens("a(b)c")          -> True   (parens seen: "()")
        is_dyck_parens("ab)cd(")         -> False  (closes before it opens)
        is_dyck_parens("((x) y) z")      -> True   (parens seen: "(())")
        is_dyck_parens("(()")            -> False  (one '(' left open)
        is_dyck_parens("no parens")      -> True   (nothing to match)
    """
    level = 0
    for ch in s:
        if ch not in ('(', ')'):
            continue
        level += 1 if ch == '(' else -1
        if level < 0:
            return False
    return not level


# Optional: a debugging variant that also reports the first error position.
def check_dyck_parens(s: str) -> Tuple[bool, Optional[int], int]:
    """
    Returns (valid, first_error_index, final_depth).

    - valid: True iff the parentheses subsequence is Dyck-valid.
    - first_error_index: index in the *original string* of the first ')'
      that has no matching '(' before it; None when there is no such ')'.
      It is also None when the only problem is unclosed '(' characters,
      because no single position is reported for that case.
    - final_depth: depth when scanning stopped. It is -1 when an unmatched
      ')' ended the scan early; otherwise it is the depth after the whole
      string (0 iff balanced, positive iff some '(' were left open).
    """
    depth = 0
    for i, ch in enumerate(s):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth < 0:
                # Scanning stops at the first violation, so depth is -1 here.
                return False, i, depth
    # Reaching this point means no ')' was ever unmatched; any remaining
    # imbalance is surplus '(' and is visible to the caller via final_depth.
    return depth == 0, None, depth


# Self-test: runs when this cell/file is executed as __main__.
if __name__ == "__main__":
    # Input string -> expected is_dyck_parens result.
    tests = {
        "": True,
        "abc": True,
        "()": True,
        "(a)b": True,
        "a(b(c)d)e": True,
        "((x)y)z": True,
        "ab)cd(": False,
        "(()": False,
        "())(": False,
        "x)(": False,
        "(((foo)))bar": True,
        "foo(bar(baz)qux))": False,
    }
    for s, expected in tests.items():
        got = is_dyck_parens(s)
        # Print before asserting so the failing case is visible in the output.
        print(f"{s!r:28} -> {got} (expected {expected})")
        assert got == expected, f"Mismatch for {s!r}"
    print("All tests passed.")


''                           -> True (expected True)
'abc'                        -> True (expected True)
'()'                         -> True (expected True)
'(a)b'                       -> True (expected True)
'a(b(c)d)e'                  -> True (expected True)
'((x)y)z'                    -> True (expected True)
'ab)cd('                     -> False (expected False)
'(()'                        -> False (expected False)
'())('                       -> False (expected False)
'x)('                        -> False (expected False)
'(((foo)))bar'               -> True (expected True)
'foo(bar(baz)qux))'          -> False (expected False)
All tests passed.


In [76]:
def make_paren_code(n_sequences=50000, min_len=20, max_len=200, emit_key=False):
    """Generate newline-joined balanced paren strings whose shape depends on a hidden key.

    For each sequence a length is drawn uniformly from [min_len, max_len]
    (odd or too-short lengths are rounded up to the next even length >= 2,
    so a sequence can be one longer than an odd ``max_len``) and a key is
    drawn from '0'/'1'. Given the key and the length the body is fully
    determined:

    - key '0' opens at even depth and closes at odd depth: "()()()..."
    - key '1' opens at odd depth and closes at even depth: "(()()...())"

    Args:
        n_sequences: Number of lines to generate.
        min_len: Smallest length to draw (inclusive).
        max_len: Largest length to draw (inclusive).
        emit_key: If True, prefix each line with '0:' or '1:' so the key is
            observable.

    Returns:
        All sequences joined by newlines (an empty string for 0 sequences).
    """
    def gen_one(L):
        if L % 2 == 1 or L < 2:
            L = max(2, L + (L % 2))
        key = random.choice('01')
        open_on_even = key == '0'
        s = []
        depth = 0
        for t in range(L):
            rem = L - t
            if depth == rem:
                # Every remaining step is needed to get back to depth 0.
                ch = ')'
            elif depth == 0:
                # Nothing is open, so closing is impossible.
                ch = '('
            else:
                # depth and rem always share parity, so here depth <= rem - 2:
                # both moves are feasible and the key-driven choice needs no
                # further feasibility guard.
                even = depth % 2 == 0
                ch = '(' if even == open_on_even else ')'
            s.append(ch)
            depth += 1 if ch == '(' else -1
        assert depth == 0 and len(s) == L
        body = "".join(s)
        return (key + ":" + body) if emit_key else body
    return "\n".join(gen_one(random.randint(min_len, max_len)) for _ in range(n_sequences))

def make_dyck1_greedy(n_sequences=50000, min_len=20, max_len=200):
    """Fully nested baseline: every line is k '(' followed by k ')', e.g. '((()))'.

    A length is drawn uniformly from [min_len, max_len] per line; odd or
    too-short lengths are rounded up to the next even length >= 2, and k is
    half of that length. Lines are joined with newlines.
    """
    def gen_one(L):
        needs_rounding = L < 2 or L % 2 == 1
        if needs_rounding:
            L = max(2, L + (L % 2))
        half = L // 2
        return "".join(("(" * half, ")" * half))

    return "\n".join(
        gen_one(random.randint(min_len, max_len)) for _ in range(n_sequences)
    )

def make_dyck1_random(n_sequences=50000, min_len=20, max_len=200):
    """Random balanced-parenthesis lines of exact (even) length, joined by newlines.

    Each line's length is drawn uniformly from [min_len, max_len] and rounded
    up to the next even length >= 2 when it is odd or too short. The line is
    built as a random walk on depth that is forced to open at depth 0, forced
    to close once the depth equals the number of remaining steps, and flips a
    fair coin otherwise. The result is always a valid Dyck word but is not
    uniformly distributed over all Dyck words of that length.
    """
    def gen_one(L):
        # normalise to an even length of at least 2
        if L < 2 or L % 2 == 1:
            L = max(2, L + (L % 2))
        chars = []
        level = 0
        remaining = L
        while remaining > 0:
            if level == remaining:
                # every remaining step must close to finish at depth 0
                opening = False
            elif level == 0:
                # nothing open, so the only legal move is to open
                opening = True
            else:
                # open only if it can still be closed afterwards, then on a coin flip
                room_to_open = (level + 1) <= (remaining - 1)
                opening = room_to_open and random.random() < 0.5
            chars.append('(' if opening else ')')
            level += 1 if opening else -1
            remaining -= 1
        assert level == 0 and len(chars) == L
        return "".join(chars)

    return "\n".join(
        gen_one(random.randint(min_len, max_len)) for _ in range(n_sequences)
    )

def make_copy_dataset(n_sequences=50000, min_len=5, max_len=40, alphabet="abcd"):
    """Build newline-joined copy-task lines of the form '<s>X#X</s>'.

    X is a string whose length is drawn uniformly from [min_len, max_len] and
    whose characters are drawn independently from ``alphabet``.
    """
    def one_example():
        length = random.randint(min_len, max_len)
        payload = "".join([random.choice(alphabet) for _ in range(length)])
        return f"<s>{payload}#{payload}</s>"

    return "\n".join(one_example() for _ in range(n_sequences))

# Quick visual check: print 10 random Dyck-1 lines with drawn lengths between 10 and 20.
xx = make_dyck1_random(n_sequences=10, min_len=10, max_len=20)
print(xx)


(((()())))(())((()))
(()(()()()()(())))
()()(()((()())))()
()((())())(((())))
()()((()()))()()()
(()(((()))()(()())))
()(())(((())))
(())(((((())))))
(())((()))
()()(()())((()))


In [13]:
_RULE = "=" * 80


def _print_section(title):
    """Print a blank line, a rule, ``title`` and a closing rule."""
    print(f"\n{_RULE}")
    print(title)
    print(_RULE)


def _ask_temperature(prompt_text, default=1.0):
    """Ask for a sampling temperature; return ``default`` on blank, unparsable or negative input."""
    raw = input(prompt_text).strip()
    if not raw:
        return default
    try:
        value = float(raw)
    except ValueError:
        # A typo should not kill the interactive session.
        print(f"Could not parse {raw!r} as a number, using {default}.")
        return default
    if value < 0:
        print(f"Temperature must be non-negative, using {default}.")
        return default
    return value


def main(checkpoint_path="/nfs/shuozhe/clean_pretrain/checkpoints/checkpoint_iter_2000.pt", device=None):
    """Interactive console loop for sampling from a TinyGPT checkpoint.

    Loads the checkpoint once, then repeatedly offers three options: type a
    custom prompt, use the first characters of a freshly generated
    dyck_random sequence as the prompt, or quit. After each generation the
    output and a balanced-parentheses check are printed.

    Args:
        checkpoint_path: Checkpoint file to load. Defaults to the path that
            was previously hard-coded.
        device: Device to run on. ``None`` selects 'cuda' when available and
            'cpu' otherwise.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(_RULE)
    print("TINYGPT CHECKPOINT TESTER")
    print(_RULE)
    print()

    # Load model
    model, stoi, itos, cfg = load_checkpoint(checkpoint_path, device=device)

    while True:
        print("\n" + _RULE)
        print("OPTIONS:")
        print("  1. Enter custom input")
        print("  2. Use dyck_random example")
        print("  3. Quit")
        print(_RULE)

        choice = input("\nYour choice (1/2/3): ").strip()

        if choice == '3':
            print("Exiting...")
            break

        elif choice == '1':
            # Custom input
            print("\n--- Custom Input Mode ---")
            user_input = input("Enter your input prompt: ").strip()

            if not user_input:
                print("Empty input, skipping...")
                continue

            _print_section(f"INPUT: {user_input}")

            temp = _ask_temperature("Temperature (default 1.0, press Enter to use default): ")

            output = generate_continuation(
                model, user_input, stoi, itos,
                max_new_tokens=150,
                device=device,
                temperature=temp
            )

            _print_section("MODEL OUTPUT:")
            print(output)
            print(f"\n{_RULE}")
            print("Analysis:")
            print(f"  Length: {len(output)}")
            print(f"  Valid (balanced parens): {check_validity(output)}")
            print(_RULE)

        elif choice == '2':
            # Dyck random example
            print("\n--- Dyck Random Mode ---")

            # Generate one dyck_random sequence to serve as ground truth
            dyck_samples = make_dyck1_random(n_sequences=1, min_len=30, max_len=80)
            full_sequence = dyck_samples.strip()

            # Use first few characters as prompt
            prompt_len = min(5, len(full_sequence))
            prompt = full_sequence[:prompt_len]

            _print_section("GROUND TRUTH (full dyck_random sequence):")
            print(full_sequence)
            print(f"\nLength: {len(full_sequence)}, Valid: {check_validity(full_sequence)}")

            _print_section(f"INPUT (first {prompt_len} characters as prompt):")
            print(prompt)

            temp = _ask_temperature("\nTemperature (default 1.0, press Enter to use default): ")

            # The token budget equals the ground-truth length, so the model may
            # produce up to that many tokens on top of the prompt.
            output = generate_continuation(
                model, prompt, stoi, itos,
                max_new_tokens=len(full_sequence),
                device=device,
                temperature=temp
            )

            _print_section("MODEL OUTPUT:")
            print(output)

            _print_section("COMPARISON:")
            print(f"Ground truth length: {len(full_sequence)}")
            print(f"Generated length:    {len(output)}")
            print(f"Ground truth valid:  {check_validity(full_sequence)}")
            print(f"Generated valid:     {check_validity(output)}")
            print(_RULE)

        else:
            print("Invalid choice, please try again.")

        # Ask if user wants to continue
        cont = input("\nTest another input? (y/n): ").strip().lower()
        if cont != 'y':
            print("Exiting...")
            break


if __name__ == "__main__":
    main()

  checkpoint = torch.load(checkpoint_path, map_location=device)


TINYGPT CHECKPOINT TESTER

Loading checkpoint from: /nfs/shuozhe/clean_pretrain/checkpoints/checkpoint_iter_2000.pt
✓ Successfully loaded checkpoint from iteration 2000
✓ Model config: vocab_size=5, block_size=256
  n_layer=4, n_head=4, n_embd=256
✓ Training objective: acc_pg


OPTIONS:
  1. Enter custom input
  2. Use dyck_random example
  3. Quit

--- Custom Input Mode ---

INPUT: ((

MODEL OUTPUT:
(()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()()(

Analysis:
  Length: 152
  Valid (balanced parens): False
Exiting...
